mindspore 2.0.0a0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the package's page on its registry for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
- mindspore/__init__.py +4 -2
- mindspore/_akg/akg/composite/build_module.py +11 -0
- mindspore/_akg/akg/config/repository_cuda.json +11 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +52 -57
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/config/super_bar_config.json +512 -0
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libicudata.so.69 +0 -0
- mindspore/lib/libicui18n.so.69 +0 -0
- mindspore/lib/libicuuc.so.69 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so → libmindspore_ascend.so.2} +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/scipy/linalg.py +10 -114
- mindspore/scipy/ops.py +2 -2
- mindspore/scipy/ops_wrapper.py +1 -1
- mindspore/scipy/optimize/_bfgs.py +1 -1
- mindspore/scipy/optimize/_lagrange.py +200 -0
- mindspore/scipy/optimize/line_search.py +3 -2
- mindspore/scipy/optimize/minimize.py +41 -2
- mindspore/scipy/sparse/__init__.py +2 -2
- mindspore/scipy/sparse/linalg.py +4 -464
- mindspore/scipy/utils.py +1 -1
- mindspore/scipy/utils_const.py +7 -1
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +648 -574
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/rewrite/node.py
CHANGED
|
@@ -20,7 +20,7 @@ import inspect
|
|
|
20
20
|
from mindspore.nn import Cell
|
|
21
21
|
from mindspore.ops import Primitive
|
|
22
22
|
from mindspore import log as logger
|
|
23
|
-
from ..
|
|
23
|
+
from .. import _checkparam as Validator
|
|
24
24
|
from .ast_helpers import AstModifier
|
|
25
25
|
from .api.scoped_value import ScopedValue, ValueType
|
|
26
26
|
from .api.node_type import NodeType
|
|
@@ -222,6 +222,32 @@ class Node:
|
|
|
222
222
|
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
|
|
223
223
|
name, None)
|
|
224
224
|
|
|
225
|
+
@classmethod
|
|
226
|
+
def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
|
|
227
|
+
op_type: ScopedValue, args: [ScopedValue],
|
|
228
|
+
ops: {str: list}, name: str = ""):
|
|
229
|
+
"""
|
|
230
|
+
Class method of Node. Instantiate an instance of node whose type is `MathOps` .
|
|
231
|
+
A mathops node is used to represent a node with mathematical operations, such as
|
|
232
|
+
`y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of
|
|
236
|
+
node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and
|
|
237
|
+
ast.Compare.
|
|
238
|
+
targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`.
|
|
239
|
+
See detail in docstring of Node class.
|
|
240
|
+
op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
|
|
241
|
+
args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
|
|
242
|
+
sequentially in the list.
|
|
243
|
+
ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
|
|
244
|
+
saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
|
|
245
|
+
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
|
246
|
+
Name of node also used as field name in network class. The format of mathops node name
|
|
247
|
+
is 'AstNodeName_AstOpName_n'.
|
|
248
|
+
"""
|
|
249
|
+
return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
|
|
250
|
+
|
|
225
251
|
@staticmethod
|
|
226
252
|
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
227
253
|
func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
|
|
@@ -624,7 +650,8 @@ class Node:
|
|
|
624
650
|
"""
|
|
625
651
|
self._targets = targets
|
|
626
652
|
if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
|
|
627
|
-
NodeType.Tree, NodeType.CallFunction
|
|
653
|
+
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer,
|
|
654
|
+
NodeType.MathOps):
|
|
628
655
|
self._sync_assign_targets_to_ast()
|
|
629
656
|
|
|
630
657
|
def get_func(self) -> ScopedValue:
|
|
@@ -721,12 +748,12 @@ class Node:
|
|
|
721
748
|
ValueError: If `node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
|
|
722
749
|
"""
|
|
723
750
|
Validator.check_value_type("node", node, [Node], "Node")
|
|
724
|
-
Validator.check_int_range(arg_idx, 0, self._args_num,
|
|
751
|
+
Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
|
|
725
752
|
if out_idx is None:
|
|
726
753
|
if len(node._targets) != 1:
|
|
727
754
|
raise RuntimeError("node should has one output when out_idx is not provided")
|
|
728
755
|
out_idx = 0
|
|
729
|
-
Validator.check_int_range(out_idx, 0, len(node._targets),
|
|
756
|
+
Validator.check_int_range(out_idx, 0, len(node._targets), Validator.INC_LEFT, "arg_idx")
|
|
730
757
|
new_arg = node._targets[out_idx]
|
|
731
758
|
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
|
|
732
759
|
self._sync_arg()
|
|
@@ -743,7 +770,7 @@ class Node:
|
|
|
743
770
|
Raises:
|
|
744
771
|
ValueError: If `index` is out of range.
|
|
745
772
|
"""
|
|
746
|
-
Validator.check_int_range(index, 0, self._args_num,
|
|
773
|
+
Validator.check_int_range(index, 0, self._args_num, Validator.INC_LEFT, "index")
|
|
747
774
|
Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
|
|
748
775
|
if isinstance(arg, str):
|
|
749
776
|
arg = ScopedValue.create_naming_value(arg)
|
|
@@ -763,7 +790,7 @@ class Node:
|
|
|
763
790
|
Raises:
|
|
764
791
|
TypeError: Element of new argument is not an instance of ScopedValue.
|
|
765
792
|
"""
|
|
766
|
-
Validator.check_int_range(len(args), 0, self._args_num,
|
|
793
|
+
Validator.check_int_range(len(args), 0, self._args_num, Validator.INC_LEFT, "Length of args")
|
|
767
794
|
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
|
|
768
795
|
for arg_index, arg in enumerate(args):
|
|
769
796
|
if not isinstance(arg, ScopedValue):
|
|
@@ -783,7 +810,7 @@ class Node:
|
|
|
783
810
|
TypeError: Value of new argument is not an instance of ScopedValue.
|
|
784
811
|
RuntimeError: Length of new arguments is not equal to length of old arguments.
|
|
785
812
|
"""
|
|
786
|
-
Validator.check_int_range(len(kwargs), 0, self._kwargs_num,
|
|
813
|
+
Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Validator.INC_LEFT, "Length of kwargs")
|
|
787
814
|
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
|
|
788
815
|
for key, arg in kwargs.items():
|
|
789
816
|
if key not in self._normalized_args.keys() or key not in self._normalized_args_keys:
|
|
@@ -1099,7 +1126,7 @@ class Node:
|
|
|
1099
1126
|
elt.id = scoped_value.value
|
|
1100
1127
|
elif isinstance(elt, ast.Attribute) and isinstance(elt.value, ast.Name):
|
|
1101
1128
|
elt.value.id = scoped_value.scope
|
|
1102
|
-
elt.
|
|
1129
|
+
elt.attr = scoped_value.value
|
|
1103
1130
|
else:
|
|
1104
1131
|
raise RuntimeError("Only support constant or symbol in tuple now")
|
|
1105
1132
|
else:
|
|
@@ -1133,14 +1160,50 @@ class Node:
|
|
|
1133
1160
|
raise RuntimeError("Unsupported return value type: ", return_value_ast)
|
|
1134
1161
|
ast.fix_missing_locations(return_ast)
|
|
1135
1162
|
|
|
1163
|
+
def _sync_mathops_node_args_to_ast(self):
|
|
1164
|
+
"""
|
|
1165
|
+
Sync values from self._normalized_args to the ast node for mathematical operations.
|
|
1166
|
+
"""
|
|
1167
|
+
if self._ast_node is None:
|
|
1168
|
+
return
|
|
1169
|
+
if not isinstance(self._ast_node, ast.Assign):
|
|
1170
|
+
raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
|
|
1171
|
+
mathops_node = self._ast_node.value
|
|
1172
|
+
if isinstance(mathops_node, ast.BinOp):
|
|
1173
|
+
left = mathops_node.left
|
|
1174
|
+
right = mathops_node.right
|
|
1175
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
|
|
1176
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[1]), right)
|
|
1177
|
+
elif isinstance(mathops_node, ast.UnaryOp):
|
|
1178
|
+
operand = mathops_node.operand
|
|
1179
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
|
|
1180
|
+
elif isinstance(mathops_node, ast.BoolOp):
|
|
1181
|
+
values = mathops_node.values
|
|
1182
|
+
for arg_index in range(self._args_num):
|
|
1183
|
+
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
|
1184
|
+
AstModifier.update_arg_value(arg_value, values[arg_index])
|
|
1185
|
+
elif isinstance(mathops_node, ast.Compare):
|
|
1186
|
+
left = mathops_node.left
|
|
1187
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
|
|
1188
|
+
comparators = mathops_node.comparators
|
|
1189
|
+
for arg_index in range(1, self._args_num):
|
|
1190
|
+
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
|
1191
|
+
AstModifier.update_arg_value(arg_value, comparators[arg_index - 1])
|
|
1192
|
+
else:
|
|
1193
|
+
raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
|
|
1194
|
+
"ast.BoolOp, ast.Compare), but got ", type(mathops_node))
|
|
1195
|
+
|
|
1136
1196
|
def _sync_arg(self):
|
|
1137
1197
|
"""Sync _normalized_args to corresponding ast node when updated."""
|
|
1138
|
-
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree
|
|
1198
|
+
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,\
|
|
1199
|
+
NodeType.CellContainer, NodeType.CallFunction):
|
|
1139
1200
|
self._sync_call_cell_args_to_ast()
|
|
1140
1201
|
elif self._node_type == NodeType.Output:
|
|
1141
1202
|
self._sync_return_node_to_ast()
|
|
1142
1203
|
elif self._node_type == NodeType.CallMethod:
|
|
1143
1204
|
self._sync_call_method_args_to_ast()
|
|
1205
|
+
elif self._node_type == NodeType.MathOps:
|
|
1206
|
+
self._sync_mathops_node_args_to_ast()
|
|
1144
1207
|
|
|
1145
1208
|
|
|
1146
1209
|
class TreeNode(Node):
|
|
@@ -1188,8 +1251,6 @@ class TreeNode(Node):
|
|
|
1188
1251
|
instance: Object in network corresponding to this node.
|
|
1189
1252
|
"""
|
|
1190
1253
|
|
|
1191
|
-
if not isinstance(instance, Cell):
|
|
1192
|
-
raise ValueError("Argument instance should be a Cell: ", type(instance))
|
|
1193
1254
|
non_custom_args = Node._handle_custom_obj_in_args(args)
|
|
1194
1255
|
non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
|
|
1195
1256
|
new_targets = Node._handle_targets(targets)
|
|
@@ -1198,3 +1259,88 @@ class TreeNode(Node):
|
|
|
1198
1259
|
if ast_node is None:
|
|
1199
1260
|
ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
|
|
1200
1261
|
return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
|
|
1262
|
+
|
|
1263
|
+
|
|
1264
|
+
class CellContainer(Node):
|
|
1265
|
+
""" Container for saving cell-objects node. """
|
|
1266
|
+
class _Visitor():
|
|
1267
|
+
""" A iterator of CellContainer nodes. """
|
|
1268
|
+
def __init__(self, cellcontainer):
|
|
1269
|
+
self._cellcontainer = cellcontainer
|
|
1270
|
+
|
|
1271
|
+
def __len__(self):
|
|
1272
|
+
""" Get the number of nodes. """
|
|
1273
|
+
return self._cellcontainer.node_count
|
|
1274
|
+
|
|
1275
|
+
def __iter__(self):
|
|
1276
|
+
"""Create an iterator over the CellContainer."""
|
|
1277
|
+
count = len(self._cellcontainer.node_list)
|
|
1278
|
+
i = 0
|
|
1279
|
+
while i < count:
|
|
1280
|
+
curr = self._cellcontainer.node_list[i]
|
|
1281
|
+
if curr.valid:
|
|
1282
|
+
yield curr
|
|
1283
|
+
i += 1
|
|
1284
|
+
|
|
1285
|
+
def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
|
|
1286
|
+
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
|
|
1287
|
+
"""Constructor of CellContainer.
|
|
1288
|
+
|
|
1289
|
+
Args:
|
|
1290
|
+
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
1291
|
+
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1292
|
+
func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
1293
|
+
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1294
|
+
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1295
|
+
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
1296
|
+
Name of node also used as field name in network class.
|
|
1297
|
+
instance: Object in network corresponding to this node.
|
|
1298
|
+
"""
|
|
1299
|
+
if isinstance(func, str):
|
|
1300
|
+
func = ScopedValue.create_naming_value(func)
|
|
1301
|
+
super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
|
|
1302
|
+
self._node_list = list()
|
|
1303
|
+
self._node_count = 0
|
|
1304
|
+
|
|
1305
|
+
@property
|
|
1306
|
+
def node_count(self):
|
|
1307
|
+
"""Number of nodes."""
|
|
1308
|
+
return len(self._node_list)
|
|
1309
|
+
|
|
1310
|
+
@property
|
|
1311
|
+
def node_list(self):
|
|
1312
|
+
""" Get node list. """
|
|
1313
|
+
return self._node_list
|
|
1314
|
+
|
|
1315
|
+
def append(self, node):
|
|
1316
|
+
""" Append new node to node list. """
|
|
1317
|
+
setattr(node, "container", self)
|
|
1318
|
+
setattr(node, "valid", True)
|
|
1319
|
+
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1320
|
+
self._node_list.append(node)
|
|
1321
|
+
# when creating a cell_container, node instance is already in SequentialCell cell_list
|
|
1322
|
+
# so here we need to write a if judgement
|
|
1323
|
+
if node.get_instance() not in self.get_instance().cell_list:
|
|
1324
|
+
self.get_instance().append(node.get_instance())
|
|
1325
|
+
|
|
1326
|
+
def erase(self, node):
|
|
1327
|
+
"""Erase node form container."""
|
|
1328
|
+
index_node = self.node_list.index(node)
|
|
1329
|
+
index_instance = self.get_instance().cell_list.index(node.get_instance())
|
|
1330
|
+
if index_node != index_instance:
|
|
1331
|
+
raise RuntimeError("In MindSpore Rewrite CellContainer, erasing a node raises index error!!!")
|
|
1332
|
+
setattr(node, "valid", False)
|
|
1333
|
+
del self.get_instance()[index_node]
|
|
1334
|
+
del self._node_list[index_node]
|
|
1335
|
+
|
|
1336
|
+
def insert(self, index, node):
|
|
1337
|
+
"""Insert node into container"""
|
|
1338
|
+
self.node_list.insert(index, node)
|
|
1339
|
+
setattr(node, "container", self)
|
|
1340
|
+
setattr(node, "valid", True)
|
|
1341
|
+
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1342
|
+
self.get_instance()._insert(index, node.get_instance())
|
|
1343
|
+
|
|
1344
|
+
def nodes(self):
|
|
1345
|
+
""" Return a iterator of node."""
|
|
1346
|
+
return self._Visitor(self)
|
|
@@ -15,21 +15,23 @@
|
|
|
15
15
|
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
16
16
|
from typing import Union
|
|
17
17
|
import ast
|
|
18
|
+
import sys
|
|
19
|
+
import inspect
|
|
18
20
|
import astunparse
|
|
19
21
|
|
|
20
22
|
from mindspore import log as logger
|
|
21
23
|
from mindspore._extends.parse.namespace import CellNamespace
|
|
22
|
-
from mindspore.nn import Cell
|
|
24
|
+
from mindspore.nn import Cell, SequentialCell
|
|
23
25
|
from mindspore.ops import operations as P
|
|
24
26
|
from mindspore.ops import Primitive
|
|
25
27
|
from mindspore.rewrite.parser_register import ParserRegister
|
|
26
28
|
from mindspore.rewrite.namespace import is_subtree, is_functional, get_functional
|
|
27
29
|
from mindspore.rewrite.symbol_tree import SymbolTree
|
|
28
|
-
from mindspore.rewrite.node import Node, TreeNode
|
|
30
|
+
from mindspore.rewrite.node import Node, TreeNode, CellContainer
|
|
29
31
|
from mindspore.rewrite.parser import Parser
|
|
30
32
|
from mindspore.rewrite.parser_register import reg_parser
|
|
31
33
|
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
|
|
32
|
-
from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder
|
|
34
|
+
from mindspore.rewrite.symbol_tree_builder import SymbolTreeBuilder, FunctionSymbolTreeBuilder
|
|
33
35
|
from mindspore.rewrite.ast_helpers import AstReplacer, AstModifier
|
|
34
36
|
from mindspore.rewrite.common.event import Event
|
|
35
37
|
from ..common import error_str
|
|
@@ -65,7 +67,7 @@ class AssignParser(Parser):
|
|
|
65
67
|
tuple_elts = node.elts
|
|
66
68
|
tuple_values = []
|
|
67
69
|
for tuple_elt in tuple_elts:
|
|
68
|
-
if not isinstance(tuple_elt, (ast.Constant, ast.Name)):
|
|
70
|
+
if not isinstance(tuple_elt, (ast.Constant, ast.Name, ast.Attribute)):
|
|
69
71
|
raise RuntimeError(f"Only support ast.Constant or ast.Name as elts of ast.Tuple, "
|
|
70
72
|
f"but got ast type {type(tuple_elt).__name__}",
|
|
71
73
|
child_node=tuple_elt, father_node=node)
|
|
@@ -73,6 +75,8 @@ class AssignParser(Parser):
|
|
|
73
75
|
tuple_values.append(tuple_elt.value)
|
|
74
76
|
elif isinstance(tuple_elt, ast.Name):
|
|
75
77
|
tuple_values.append(tuple_elt.id)
|
|
78
|
+
elif isinstance(tuple_elt, ast.Attribute):
|
|
79
|
+
tuple_values.append("".join([tuple_elt.value.id, '.', tuple_elt.attr]))
|
|
76
80
|
return ScopedValue.create_variable_value(tuple(tuple_values))
|
|
77
81
|
|
|
78
82
|
@staticmethod
|
|
@@ -281,15 +285,15 @@ class AssignParser(Parser):
|
|
|
281
285
|
if len(body.targets) > 1:
|
|
282
286
|
raise NotImplementedError(error_str("not support multi-targets in assign now!", father_node=body))
|
|
283
287
|
target = body.targets[0]
|
|
284
|
-
if not isinstance(target, ast.Attribute) or not (target.value, ast.Name)
|
|
288
|
+
if not isinstance(target, ast.Attribute) or not isinstance(target.value, ast.Name):
|
|
285
289
|
continue
|
|
286
|
-
if target.attr != func_name:
|
|
290
|
+
if target.value.id != "self" or target.attr != func_name:
|
|
287
291
|
continue
|
|
288
292
|
changed = True
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
+
setattr(stree.get_origin_network(), func_name, sub_tree.get_origin_network())
|
|
294
|
+
args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
|
|
295
|
+
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
296
|
+
ScopedValue(ValueType.StringValue, "", func_name)])
|
|
293
297
|
body.value = ast.Call(func=ast.Name(class_name, ast.Store()), args=[args_call], keywords=[])
|
|
294
298
|
break
|
|
295
299
|
return changed
|
|
@@ -308,6 +312,91 @@ class AssignParser(Parser):
|
|
|
308
312
|
call_args = [AssignParser._create_scopedvalue(arg) for arg in father_ast_node.value.args]
|
|
309
313
|
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, {})
|
|
310
314
|
|
|
315
|
+
@staticmethod
|
|
316
|
+
def _create_inputs_for_cell_container(father_ast_node) -> ['Node']:
|
|
317
|
+
"""Create inputs for cell container first node."""
|
|
318
|
+
call_ast_node = father_ast_node.value
|
|
319
|
+
if not isinstance(call_ast_node, ast.Call):
|
|
320
|
+
raise RuntimeError(error_str(f"when creating input node for cellcontainer, value of input father ast node"
|
|
321
|
+
"is not ast.Call!'", child_node=call_ast_node, father_node=father_ast_node))
|
|
322
|
+
first_node_inputs: ['Node'] = []
|
|
323
|
+
exist_param_name = []
|
|
324
|
+
for arg in call_ast_node.args:
|
|
325
|
+
if isinstance(arg, ast.Name):
|
|
326
|
+
param_name = arg.id
|
|
327
|
+
elif isinstance(arg, ast.arg):
|
|
328
|
+
param_name = arg.arg
|
|
329
|
+
else:
|
|
330
|
+
raise RuntimeError(error_str(f"only support ast.arg, ast.arg in arguments arg, but got "
|
|
331
|
+
f"'{type(arg).__name__}'", child_node=arg, father_node=call_ast_node))
|
|
332
|
+
if param_name in exist_param_name:
|
|
333
|
+
raise RuntimeError(error_str(f"Cellcontianer has duplicate input names", child_node=arg,
|
|
334
|
+
father_node=call_ast_node))
|
|
335
|
+
exist_param_name.append(param_name)
|
|
336
|
+
node = Node.create_input_node(arg, param_name, name=f"input_{param_name}")
|
|
337
|
+
first_node_inputs.append(node)
|
|
338
|
+
|
|
339
|
+
if call_ast_node.keywords:
|
|
340
|
+
raise RuntimeError(error_str(f"Not support keyword input for cellcontainer now.",
|
|
341
|
+
child_node=call_ast_node, father_node=father_ast_node))
|
|
342
|
+
|
|
343
|
+
return first_node_inputs
|
|
344
|
+
|
|
345
|
+
def _cell_container_process(self, ast_node, stree, targets, func, call_args, call_kwargs, op_name, container_obj):
|
|
346
|
+
""" parse cell container object."""
|
|
347
|
+
cell_container = CellContainer(ast_node, targets, func, call_args, call_kwargs, op_name, container_obj)
|
|
348
|
+
cell_container.set_belong_symbol_tree(stree)
|
|
349
|
+
first_node_inputs = AssignParser._create_inputs_for_cell_container(ast_node)
|
|
350
|
+
for i, cell in enumerate(container_obj):
|
|
351
|
+
is_sub_tree = is_subtree(type(cell).__name__)
|
|
352
|
+
if is_sub_tree:
|
|
353
|
+
stb = SymbolTreeBuilder(cell)
|
|
354
|
+
new_stree = stb.build()
|
|
355
|
+
replacer = AstReplacer(new_stree.get_class_ast())
|
|
356
|
+
replacer.replace_all(new_stree.get_ori_cls_name(), new_stree.get_opt_cls_name())
|
|
357
|
+
sub_node = TreeNode.create_tree_node(new_stree, ast_node, targets, func, call_args, call_kwargs,
|
|
358
|
+
type(cell).__name__, cell)
|
|
359
|
+
else:
|
|
360
|
+
sub_node = Node.create_call_buildin_op(cell, ast_node, targets, func, call_args, call_kwargs,
|
|
361
|
+
type(cell).__name__)
|
|
362
|
+
# add sub node to cell_container
|
|
363
|
+
cell_container.append(sub_node)
|
|
364
|
+
# set node inputs
|
|
365
|
+
if i == 0:
|
|
366
|
+
sub_node.set_inputs(first_node_inputs)
|
|
367
|
+
else:
|
|
368
|
+
sub_node.set_inputs([cell_container.node_list[i-1]])
|
|
369
|
+
return cell_container
|
|
370
|
+
|
|
371
|
+
def _process_external_function(self, stree, func_name):
|
|
372
|
+
"""Process external function."""
|
|
373
|
+
for k, m in sys.modules.items():
|
|
374
|
+
if k in ("_ast", "ast"):
|
|
375
|
+
continue
|
|
376
|
+
if hasattr(m, func_name):
|
|
377
|
+
func = getattr(m, func_name)
|
|
378
|
+
source_code = inspect.getsource(func)
|
|
379
|
+
ast_root: ast.Module = ast.parse(source_code)
|
|
380
|
+
stree._external_func_ast.append(ast_root.body[0]) # pylint: disable=protected-access
|
|
381
|
+
return func, ast_root.body[0]
|
|
382
|
+
return None, None
|
|
383
|
+
|
|
384
|
+
def _process_internal_function(self, stree: SymbolTree, func_name):
|
|
385
|
+
"""Process internal function."""
|
|
386
|
+
func = getattr(stree._origin_network, func_name) # pylint: disable=protected-access
|
|
387
|
+
ast_node = None
|
|
388
|
+
for body in stree._class_ast.body: # pylint: disable=protected-access
|
|
389
|
+
if isinstance(body, ast.FunctionDef) and func_name == body.name:
|
|
390
|
+
ast_node = body
|
|
391
|
+
return func, ast_node
|
|
392
|
+
|
|
393
|
+
def _create_func_subtree(self, op, targets, father_ast_node, ast_node, call_args, call_kwargs, func_name):
|
|
394
|
+
"""Create subtree of function."""
|
|
395
|
+
stb = FunctionSymbolTreeBuilder(op, ast_node)
|
|
396
|
+
new_stree = stb.build()
|
|
397
|
+
return TreeNode.create_tree_node(new_stree, father_ast_node, targets, func_name, call_args, call_kwargs,
|
|
398
|
+
func_name, op)
|
|
399
|
+
|
|
311
400
|
def _convert_ast_call_to_node(self, ast_node: ast.Call, father_ast_node: ast.Assign, stree: SymbolTree) -> Node:
|
|
312
401
|
"""
|
|
313
402
|
Convert ast.Call to a symbol tree node.
|
|
@@ -340,9 +429,19 @@ class AssignParser(Parser):
|
|
|
340
429
|
func = get_functional(func_name.split(".")[-1])
|
|
341
430
|
node = stree.inner_create_call_function(func_name, father_ast_node, func_name, func, targets,
|
|
342
431
|
call_args, call_kwargs)
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
432
|
+
elif hasattr(stree._origin_network, func_name): # pylint: disable=protected-access
|
|
433
|
+
func, ast_node = self._process_internal_function(stree, func_name)
|
|
434
|
+
node = self._create_func_subtree(func, targets, father_ast_node, ast_node, call_args, call_kwargs,
|
|
435
|
+
func_name)
|
|
436
|
+
else:
|
|
437
|
+
func, ast_node = self._process_external_function(stree, func_name)
|
|
438
|
+
node = self._create_func_subtree(func, targets, father_ast_node, ast_node, call_args, call_kwargs,
|
|
439
|
+
func_name)
|
|
440
|
+
return node
|
|
441
|
+
if isinstance(op, SequentialCell):
|
|
442
|
+
node = self._cell_container_process(father_ast_node, stree, targets, func, call_args, call_kwargs,
|
|
443
|
+
func_name, op)
|
|
444
|
+
return node
|
|
346
445
|
if isinstance(op, Primitive):
|
|
347
446
|
return Node.create_call_buildin_op(op, father_ast_node, targets, func, call_args, call_kwargs, func_name)
|
|
348
447
|
if isinstance(op, Cell):
|
|
@@ -394,6 +493,74 @@ class AssignParser(Parser):
|
|
|
394
493
|
raise RuntimeError("For MindSpore Rewrite, only support Primitive or Cell operator or Primitive operator, got ",
|
|
395
494
|
type(op).__name__)
|
|
396
495
|
|
|
496
|
+
@staticmethod
|
|
497
|
+
def _tuple_elts_support_scopledvalue(value: ast.Tuple) -> bool:
|
|
498
|
+
""" check whether each element's type in tuple is supported by scopled value. """
|
|
499
|
+
if not isinstance(value, ast.Tuple):
|
|
500
|
+
raise RuntimeError("For AssignParser._tuple_elts_support_scopledvalue(), the type of value should be "
|
|
501
|
+
f"Tuple, but got {type(value).__name__}")
|
|
502
|
+
|
|
503
|
+
for elt in value.elts:
|
|
504
|
+
if not isinstance(elt, (ast.Name, ast.Attribute, ast.Tuple, ast.Constant, ast.Num, ast.Str, ast.Bytes)):
|
|
505
|
+
return False
|
|
506
|
+
return True
|
|
507
|
+
|
|
508
|
+
@staticmethod
|
|
509
|
+
def _convert_ast_mathops_to_node(ast_node: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare],
|
|
510
|
+
father_ast_node: ast.Assign) -> Node:
|
|
511
|
+
"""
|
|
512
|
+
Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to
|
|
513
|
+
a symbol tree node.
|
|
514
|
+
|
|
515
|
+
Args:
|
|
516
|
+
ast_node (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival
|
|
517
|
+
operation in construct function.
|
|
518
|
+
father_ast_node (ast.Assign): Assign node in construct.
|
|
519
|
+
|
|
520
|
+
Returns:
|
|
521
|
+
An instance of Node in Symbol Tree.
|
|
522
|
+
|
|
523
|
+
Raises:
|
|
524
|
+
TypeError: The type of parameter 'ast_node' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare).
|
|
525
|
+
|
|
526
|
+
"""
|
|
527
|
+
if not isinstance(ast_node, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
|
|
528
|
+
raise TypeError("The type of parameter 'ast_node' must be one of (ast.BinOp, ast.UnaryOp, "
|
|
529
|
+
"ast.BoolOp, ast.Compare), but got ", type(ast_node))
|
|
530
|
+
|
|
531
|
+
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(father_ast_node.targets[0]))
|
|
532
|
+
args = []
|
|
533
|
+
op_type_str = type(ast_node).__name__
|
|
534
|
+
op_type = ScopedValue.create_naming_value(op_type_str)
|
|
535
|
+
ops = {}
|
|
536
|
+
name = op_type_str
|
|
537
|
+
if isinstance(ast_node, ast.BinOp):
|
|
538
|
+
op = type(ast_node.op).__name__
|
|
539
|
+
name = f'{name}_{op}'
|
|
540
|
+
ops['0'] = ScopedValue.create_naming_value(op)
|
|
541
|
+
args.append(AssignParser._create_scopedvalue(ast_node.left))
|
|
542
|
+
args.append(AssignParser._create_scopedvalue(ast_node.right))
|
|
543
|
+
elif isinstance(ast_node, ast.UnaryOp):
|
|
544
|
+
op = type(ast_node.op).__name__
|
|
545
|
+
name = f'{name}_{op}'
|
|
546
|
+
ops['0'] = ScopedValue.create_naming_value(op)
|
|
547
|
+
args.append(AssignParser._create_scopedvalue(ast_node.operand))
|
|
548
|
+
elif isinstance(ast_node, ast.BoolOp):
|
|
549
|
+
op = type(ast_node.op).__name__
|
|
550
|
+
name = f'{name}_{op}'
|
|
551
|
+
ops['0'] = ScopedValue.create_naming_value(op)
|
|
552
|
+
for value in ast_node.values:
|
|
553
|
+
args.append(AssignParser._create_scopedvalue(value))
|
|
554
|
+
elif isinstance(ast_node, ast.Compare):
|
|
555
|
+
args.append(AssignParser._create_scopedvalue(ast_node.left))
|
|
556
|
+
for idx, ast_op in enumerate(ast_node.ops):
|
|
557
|
+
op = type(ast_op).__name__
|
|
558
|
+
name = f'{name}_{op}'
|
|
559
|
+
ops[str(idx)] = ScopedValue.create_naming_value(op)
|
|
560
|
+
args.append(AssignParser._create_scopedvalue(ast_node.comparators[idx]))
|
|
561
|
+
name = name.lower()
|
|
562
|
+
return Node.create_mathops_node(father_ast_node, targets, op_type, args, ops, name)
|
|
563
|
+
|
|
397
564
|
def process(self, stree: SymbolTree, node: ast.Assign):
|
|
398
565
|
"""
|
|
399
566
|
Parse ast.Assign and create a node in symbol tree.
|
|
@@ -413,52 +580,63 @@ class AssignParser(Parser):
|
|
|
413
580
|
"""
|
|
414
581
|
|
|
415
582
|
targets = node.targets
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
if isinstance(value.op, ast.Add):
|
|
425
|
-
node_ = AssignParser._convert_ast_binop_to_node(value, node)
|
|
583
|
+
try:
|
|
584
|
+
if len(targets) != 1:
|
|
585
|
+
raise RuntimeError(
|
|
586
|
+
error_str(f"only support one target in assign now.", child_node=targets, father_node=node))
|
|
587
|
+
value = node.value
|
|
588
|
+
if isinstance(value, ast.Call):
|
|
589
|
+
stree.update_scope_for_unique(value)
|
|
590
|
+
node_ = self._convert_ast_call_to_node(value, node, stree)
|
|
426
591
|
stree.append_origin_field(node_)
|
|
427
|
-
|
|
592
|
+
elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)):
|
|
593
|
+
node_ = AssignParser._convert_ast_mathops_to_node(value, node)
|
|
594
|
+
stree.append_origin_field(node_)
|
|
595
|
+
elif isinstance(value, ast.Subscript):
|
|
428
596
|
logger.info(f"ops-call({astunparse.unparse(node)}) in assign will be supported in near feature, "
|
|
429
597
|
f"ignored as a python node now")
|
|
430
598
|
stree.try_append_python_node(node, node)
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
599
|
+
elif isinstance(value, (ast.Name, ast.Constant, ast.Attribute, ast.Num, ast.NameConstant,
|
|
600
|
+
ast.Bytes, ast.Str)):
|
|
601
|
+
if isinstance(value, ast.Name):
|
|
602
|
+
node_name = "name_assign"
|
|
603
|
+
elif isinstance(value, ast.Constant):
|
|
604
|
+
node_name = "constant_assign"
|
|
605
|
+
elif isinstance(value, ast.Attribute):
|
|
606
|
+
node_name = "attribute_assign"
|
|
607
|
+
stree.update_scope_for_unique(value)
|
|
608
|
+
else:
|
|
609
|
+
node_name = "other_assign"
|
|
610
|
+
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
|
|
611
|
+
call_args = [AssignParser._create_scopedvalue(value)]
|
|
612
|
+
node_ = Node.create_call_pass_through_method(node, targets, call_args, {}, node_name)
|
|
613
|
+
stree.append_origin_field(node_)
|
|
614
|
+
elif isinstance(value, ast.Tuple):
|
|
615
|
+
if AssignParser._tuple_elts_support_scopledvalue(value):
|
|
616
|
+
# ensure that each element's type in tuple is supported by scopled value
|
|
617
|
+
targets = AssignParser._get_targets(AssignParser._create_scopedvalue(node.targets[0]))
|
|
618
|
+
args = []
|
|
619
|
+
for elt in value.elts:
|
|
620
|
+
args.append(AssignParser._create_scopedvalue(elt))
|
|
621
|
+
node_ = Node.create_call_method(node, targets, ScopedValue.create_naming_value("tuple"),
|
|
622
|
+
args, {}, "tuple")
|
|
623
|
+
stree.append_origin_field(node_)
|
|
624
|
+
else:
|
|
625
|
+
logger.warning(f"some elements in Tuple of assign({astunparse.unparse(node)}) are not supported "
|
|
626
|
+
"in rewrite, fallback to python")
|
|
627
|
+
stree.try_append_python_node(node, node)
|
|
628
|
+
elif isinstance(value, (ast.List, ast.Dict)):
|
|
629
|
+
# add these as callmethod node if necessary
|
|
630
|
+
stree.try_append_python_node(node, node)
|
|
440
631
|
else:
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
args = []
|
|
449
|
-
for elt in value.elts:
|
|
450
|
-
args.append(AssignParser._create_scopedvalue(elt))
|
|
451
|
-
node_ = Node.create_call_method(node, targets, ScopedValue.create_naming_value("tuple"), args, {}, "tuple")
|
|
452
|
-
stree.append_origin_field(node_)
|
|
453
|
-
elif isinstance(value, (ast.List, ast.Dict)):
|
|
454
|
-
# add these as callmethod node if necessary
|
|
632
|
+
raise RuntimeError(
|
|
633
|
+
error_str(f"only support (ast.Call, ast.BinOp, ast.BoolOp, ast.Subscript, ast.Name, ast.Constant, "
|
|
634
|
+
f"ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str, ast.Tuple, ast.List, "
|
|
635
|
+
f"ast.Dict) as value of ast.assign, but got ast type '{type(value).__name__}'",
|
|
636
|
+
child_node=value, father_node=node))
|
|
637
|
+
except RuntimeError:
|
|
638
|
+
logger.info(f"ops-call({astunparse.unparse(node)}) not supported in rewrite, fallback to python")
|
|
455
639
|
stree.try_append_python_node(node, node)
|
|
456
|
-
else:
|
|
457
|
-
raise RuntimeError(
|
|
458
|
-
error_str(f"only support (ast.Call, ast.BinOp, ast.BoolOp, ast.Subscript, ast.Name, ast.Constant, "
|
|
459
|
-
f"ast.Attribute, ast.Num, ast.NameConstant, ast.Bytes, ast.Str, ast.Tuple, ast.List, ast.Dict"
|
|
460
|
-
f") as value of ast.assign, but got ast type '{type(value).__name__}'", child_node=value,
|
|
461
|
-
father_node=node))
|
|
462
640
|
|
|
463
641
|
|
|
464
642
|
g_assign_parser = reg_parser(AssignParser())
|