mindspore 2.0.0a0__cp37-none-any.whl → 2.0.0rc1__cp37-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +9064 -0
- mindspore/__init__.py +4 -2
- mindspore/_akg/akg/composite/build_module.py +11 -0
- mindspore/_akg/akg/config/repository_cuda.json +11 -0
- mindspore/_akg/akg/tvm/contrib/nvcc.py +4 -3
- mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +4 -3
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -48
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +16 -1
- mindspore/_extends/parse/parser.py +107 -22
- mindspore/_extends/parse/resources.py +0 -7
- mindspore/_extends/parse/standard_method.py +885 -413
- mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +52 -57
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +38 -20
- mindspore/boost/dim_reduce.py +3 -3
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -6
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +41 -7
- mindspore/common/api.py +215 -141
- mindspore/common/dtype.py +8 -1
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +4 -2
- mindspore/common/jit_config.py +17 -13
- mindspore/common/mutable.py +33 -13
- mindspore/common/parameter.py +23 -21
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +62 -41
- mindspore/common/tensor.py +852 -1154
- mindspore/communication/__init__.py +2 -2
- mindspore/communication/_comm_helper.py +11 -4
- mindspore/communication/management.py +22 -21
- mindspore/config/op_info.config +501 -1008
- mindspore/config/super_bar_config.json +512 -0
- mindspore/context.py +201 -23
- mindspore/dataset/__init__.py +6 -6
- mindspore/dataset/audio/__init__.py +7 -7
- mindspore/dataset/audio/transforms.py +670 -30
- mindspore/dataset/audio/utils.py +47 -4
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +210 -14
- mindspore/dataset/core/validator_helpers.py +2 -2
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +322 -66
- mindspore/dataset/engine/datasets_audio.py +80 -76
- mindspore/dataset/engine/datasets_standard_format.py +51 -38
- mindspore/dataset/engine/datasets_text.py +232 -118
- mindspore/dataset/engine/datasets_user_defined.py +41 -17
- mindspore/dataset/engine/datasets_vision.py +746 -225
- mindspore/dataset/engine/graphdata.py +75 -10
- mindspore/dataset/engine/iterators.py +45 -5
- mindspore/dataset/engine/offload.py +48 -28
- mindspore/dataset/engine/validators.py +117 -8
- mindspore/dataset/text/__init__.py +6 -5
- mindspore/dataset/text/transforms.py +86 -3
- mindspore/dataset/text/utils.py +6 -4
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +3 -2
- mindspore/dataset/transforms/c_transforms.py +1 -1
- mindspore/dataset/transforms/transforms.py +2 -2
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +2 -3
- mindspore/dataset/vision/c_transforms.py +9 -9
- mindspore/dataset/vision/py_transforms.py +5 -5
- mindspore/dataset/vision/py_transforms_util.py +2 -0
- mindspore/dataset/vision/transforms.py +160 -161
- mindspore/dataset/vision/utils.py +3 -3
- mindspore/experimental/map_parameter.py +38 -26
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +44 -9
- mindspore/include/api/delegate.h +1 -1
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_parallel_runner.h +2 -2
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +19 -3
- mindspore/include/api/types.h +3 -3
- mindspore/include/dataset/constants.h +7 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libicudata.so.69 +0 -0
- mindspore/lib/libicui18n.so.69 +0 -0
- mindspore/lib/libicuuc.so.69 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libmpi_collective.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/{libmindspore_ascend.so → libmindspore_ascend.so.2} +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filereader.py +18 -0
- mindspore/mindrecord/filewriter.py +197 -34
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +3 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +3 -3
- mindspore/mindrecord/tools/csv_to_mr.py +3 -3
- mindspore/mindrecord/tools/imagenet_to_mr.py +16 -11
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/nn/__init__.py +0 -4
- mindspore/nn/cell.py +204 -132
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +7 -6
- mindspore/nn/layer/__init__.py +5 -4
- mindspore/nn/layer/activation.py +40 -89
- mindspore/nn/layer/basic.py +255 -624
- mindspore/nn/layer/channel_shuffle.py +7 -6
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +41 -4
- mindspore/nn/layer/conv.py +64 -28
- mindspore/nn/layer/dense.py +9 -8
- mindspore/nn/layer/embedding.py +27 -25
- mindspore/nn/layer/image.py +53 -46
- mindspore/nn/layer/math.py +97 -105
- mindspore/nn/layer/normalization.py +117 -86
- mindspore/nn/layer/padding.py +185 -95
- mindspore/nn/layer/pooling.py +817 -414
- mindspore/nn/layer/rnn_cells.py +10 -15
- mindspore/nn/layer/rnns.py +37 -38
- mindspore/nn/layer/thor_layer.py +11 -12
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +5 -4
- mindspore/nn/loss/loss.py +334 -199
- mindspore/nn/optim/ada_grad.py +6 -6
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +4 -5
- mindspore/nn/optim/adam.py +126 -62
- mindspore/nn/optim/adamax.py +3 -4
- mindspore/nn/optim/adasum.py +6 -6
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +67 -38
- mindspore/nn/optim/lamb.py +4 -5
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +43 -4
- mindspore/nn/optim/momentum.py +6 -5
- mindspore/nn/optim/optimizer.py +3 -1
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +1 -1
- mindspore/nn/optim/rprop.py +8 -9
- mindspore/nn/optim/sgd.py +19 -13
- mindspore/nn/optim/thor.py +10 -15
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +4 -4
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +9 -15
- mindspore/nn/probability/distribution/bernoulli.py +3 -3
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +5 -7
- mindspore/nn/probability/distribution/cauchy.py +3 -3
- mindspore/nn/probability/distribution/distribution.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +3 -3
- mindspore/nn/probability/distribution/half_normal.py +15 -11
- mindspore/nn/probability/distribution/laplace.py +16 -13
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/normal.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/student_t.py +20 -15
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +27 -10
- mindspore/nn/wrap/grad_reducer.py +2 -2
- mindspore/nn/wrap/loss_scale.py +40 -24
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +35 -30
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +22 -19
- mindspore/numpy/utils.py +1 -1
- mindspore/numpy/utils_const.py +108 -58
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +86 -117
- mindspore/ops/_grad/grad_base.py +23 -1
- mindspore/ops/_grad/grad_clip_ops.py +2 -3
- mindspore/ops/_grad/grad_comm_ops.py +34 -24
- mindspore/ops/_grad/grad_implementations.py +9 -45
- mindspore/ops/_grad/grad_inner_ops.py +47 -4
- mindspore/ops/_grad/grad_math_ops.py +142 -117
- mindspore/ops/_grad/grad_nn_ops.py +71 -165
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +7 -6
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +150 -15
- mindspore/ops/_grad_experimental/grad_image_ops.py +16 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +1 -22
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +4 -11
- mindspore/ops/_grad_experimental/grad_math_ops.py +210 -89
- mindspore/ops/_grad_experimental/grad_nn_ops.py +26 -22
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +49 -8
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +4 -4
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +2 -2
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -5
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -4
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +236 -4
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_v1.py → adaptive_avg_pool_2d.py} +6 -5
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -43
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/aicpu/bessel_i0.py} +15 -8
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +0 -3
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/{adaptive_avg_pool_2d_grad_v1.py → digamma.py} +7 -9
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/lgamma.py} +16 -10
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +1 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/sparse_slice.py +4 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +6 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -611
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +2 -2
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/matmul_ds.py +2 -0
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/scatter_mul.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -2
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_register_for_op.py +1 -0
- mindspore/ops/_utils/__init__.py +1 -2
- mindspore/ops/_utils/utils.py +19 -40
- mindspore/ops/_vmap/vmap_array_ops.py +116 -38
- mindspore/ops/_vmap/vmap_base.py +16 -9
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +7 -5
- mindspore/ops/_vmap/vmap_image_ops.py +12 -5
- mindspore/ops/_vmap/vmap_math_ops.py +46 -5
- mindspore/ops/_vmap/vmap_nn_ops.py +15 -21
- mindspore/ops/_vmap/vmap_random_ops.py +1 -1
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +220 -106
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -23
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -17
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +39 -41
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +41 -43
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +51 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +24 -25
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +53 -53
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +77 -85
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +37 -39
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +70 -72
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +17 -17
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +2 -0
- mindspore/ops/composite/__init__.py +7 -8
- mindspore/ops/composite/base.py +101 -47
- mindspore/ops/composite/math_ops.py +188 -158
- mindspore/ops/composite/multitype_ops/_compile_utils.py +415 -170
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +142 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +2 -3
- mindspore/ops/composite/multitype_ops/getitem_impl.py +31 -3
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -5
- mindspore/ops/composite/multitype_ops/not_in_impl.py +9 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/setitem_impl.py +21 -3
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +35 -4
- mindspore/ops/function/__init__.py +152 -8
- mindspore/ops/function/array_func.py +2555 -674
- mindspore/ops/function/clip_func.py +209 -13
- mindspore/ops/function/debug_func.py +2 -2
- mindspore/ops/function/grad/__init__.py +2 -1
- mindspore/ops/function/grad/grad_func.py +147 -62
- mindspore/ops/function/image_func.py +54 -38
- mindspore/ops/function/linalg_func.py +167 -16
- mindspore/ops/function/math_func.py +4849 -1492
- mindspore/ops/function/nn_func.py +2573 -988
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +3 -3
- mindspore/ops/function/random_func.py +790 -73
- mindspore/ops/function/sparse_func.py +98 -78
- mindspore/ops/function/sparse_unary_func.py +54 -53
- mindspore/ops/function/spectral_func.py +27 -24
- mindspore/ops/function/vmap_func.py +22 -2
- mindspore/ops/functional.py +97 -37
- mindspore/ops/op_info_register.py +70 -28
- mindspore/ops/operations/__init__.py +47 -14
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +5 -5
- mindspore/ops/operations/_grad_ops.py +276 -187
- mindspore/ops/operations/_inner_ops.py +319 -113
- mindspore/ops/operations/_ms_kernel.py +10 -8
- mindspore/ops/operations/_ocr_ops.py +9 -9
- mindspore/ops/operations/_opaque_predicate_registry.py +4 -0
- mindspore/ops/operations/_quant_ops.py +137 -102
- mindspore/ops/operations/_rl_inner_ops.py +121 -60
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1004 -2
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +801 -466
- mindspore/ops/operations/comm_ops.py +51 -49
- mindspore/ops/operations/control_ops.py +2 -2
- mindspore/ops/operations/custom_ops.py +123 -44
- mindspore/ops/operations/debug_ops.py +24 -24
- mindspore/ops/operations/image_ops.py +240 -153
- mindspore/ops/operations/inner_ops.py +34 -50
- mindspore/ops/operations/linalg_ops.py +31 -9
- mindspore/ops/operations/math_ops.py +988 -757
- mindspore/ops/operations/nn_ops.py +965 -819
- mindspore/ops/operations/other_ops.py +51 -40
- mindspore/ops/operations/random_ops.py +204 -122
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +254 -93
- mindspore/ops/operations/spectral_ops.py +35 -3
- mindspore/ops/primitive.py +111 -9
- mindspore/parallel/_auto_parallel_context.py +189 -83
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +99 -7
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +7 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +6 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +20 -16
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +48 -111
- mindspore/parallel/_utils.py +1 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +37 -34
- mindspore/parallel/shard.py +17 -18
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +69 -47
- mindspore/profiler/parser/ascend_timeline_generator.py +49 -42
- mindspore/profiler/parser/base_timeline_generator.py +49 -56
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +98 -78
- mindspore/profiler/parser/hwts_log_parser.py +1 -1
- mindspore/profiler/parser/integrator.py +15 -14
- mindspore/profiler/parser/minddata_analyzer.py +2 -2
- mindspore/profiler/parser/msadvisor_analyzer.py +12 -25
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +2 -1
- mindspore/profiler/profiling.py +218 -186
- mindspore/rewrite/__init__.py +3 -1
- mindspore/rewrite/api/node.py +1 -114
- mindspore/rewrite/api/node_type.py +3 -0
- mindspore/rewrite/api/pattern_engine.py +31 -1
- mindspore/rewrite/api/scoped_value.py +4 -4
- mindspore/rewrite/api/symbol_tree.py +3 -78
- mindspore/rewrite/api/tree_node_helper.py +1 -1
- mindspore/rewrite/ast_creator_register.py +1 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -2
- mindspore/rewrite/ast_helpers/ast_creator.py +1 -2
- mindspore/rewrite/ast_helpers/ast_finder.py +65 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +11 -3
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +18 -2
- mindspore/rewrite/namespace.py +0 -2
- mindspore/rewrite/node.py +157 -11
- mindspore/rewrite/parsers/assign_parser.py +231 -53
- mindspore/rewrite/parsers/class_def_parser.py +187 -109
- mindspore/rewrite/parsers/for_parser.py +24 -14
- mindspore/rewrite/parsers/function_def_parser.py +21 -4
- mindspore/rewrite/parsers/if_parser.py +6 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +256 -133
- mindspore/rewrite/symbol_tree_builder.py +38 -1
- mindspore/run_check/_check_version.py +69 -63
- mindspore/run_check/run_check.py +2 -1
- mindspore/scipy/linalg.py +10 -114
- mindspore/scipy/ops.py +2 -2
- mindspore/scipy/ops_wrapper.py +1 -1
- mindspore/scipy/optimize/_bfgs.py +1 -1
- mindspore/scipy/optimize/_lagrange.py +200 -0
- mindspore/scipy/optimize/line_search.py +3 -2
- mindspore/scipy/optimize/minimize.py +41 -2
- mindspore/scipy/sparse/__init__.py +2 -2
- mindspore/scipy/sparse/linalg.py +4 -464
- mindspore/scipy/utils.py +1 -1
- mindspore/scipy/utils_const.py +7 -1
- mindspore/train/__init__.py +1 -1
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +273 -102
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +2 -2
- mindspore/train/callback/_checkpoint.py +3 -3
- mindspore/train/callback/_early_stop.py +3 -3
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +29 -31
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +3 -3
- mindspore/train/callback/_reduce_lr_on_plateau.py +4 -4
- mindspore/train/callback/_summary_collector.py +23 -16
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +15 -3
- mindspore/train/dataset_helper.py +10 -15
- mindspore/train/loss_scale_manager.py +8 -11
- mindspore/train/metrics/__init__.py +1 -1
- mindspore/train/metrics/bleu_score.py +1 -1
- mindspore/train/metrics/confusion_matrix.py +1 -1
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/dice.py +2 -2
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +4 -3
- mindspore/train/metrics/mean_surface_distance.py +2 -2
- mindspore/train/metrics/occlusion_sensitivity.py +1 -1
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +1 -1
- mindspore/train/metrics/recall.py +1 -1
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +45 -28
- mindspore/train/serialization.py +295 -188
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -13
- mindspore/train/train_thor/convert_utils.py +2 -2
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +3 -2
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +648 -574
- mindspore/compression/__init__.py +0 -19
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -515
- mindspore/compression/quant/__init__.py +0 -28
- mindspore/compression/quant/qat.py +0 -634
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -140
- mindspore/nn/probability/dpn/vae/vae.py +0 -124
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -364
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/aicpu/parallel_concat.py +0 -42
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Cast_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/MatMul_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/Transpose_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -15
- mindspore/ops/composite/array_ops.py +0 -241
- mindspore/ops/composite/clip_ops.py +0 -134
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.0.0a0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Sparsify transformer"""
|
|
16
|
+
import ast
|
|
17
|
+
import inspect
|
|
18
|
+
import textwrap
|
|
19
|
+
from collections import deque
|
|
20
|
+
import astunparse
|
|
21
|
+
|
|
22
|
+
from mindspore import ops, nn
|
|
23
|
+
from mindspore import log as logger
|
|
24
|
+
from mindspore.rewrite.parsers.assign_parser import AssignParser
|
|
25
|
+
from mindspore.rewrite.sparsify.utils import ArgType, SparseFunc, sparse_rules, get_sparse_func, builtin_ops, \
|
|
26
|
+
get_binop_name, get_sparse_method_outputs, arg_type_to_prefix_map, get_inputs_outputs
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Module-path prefix compared against ``func.__module__`` in SparseTransformer.get_sparse_node; callables whose
# module starts with it are never expanded recursively and raise ValueError when no sparse rule matches.
OPS_MODULE = "mindspore.ops."
|
|
30
|
+
# Nesting depth at which SparseTransformer.get_sparse_node stops expanding callables and raises RuntimeError.
MAX_RECURSION_DEPTH = 10
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def sparsify_helper(f, arg_types, user_defined_rules=None, sparse_name="", full_sparse_rules=None, depth=0):
    """
    Parse the source of `f` and run a SparseTransformer over every statement of its body.

    Args:
        f: A function, or an `nn.Cell` whose `construct` method is parsed (its leading `self` parameter is removed
            from the parsed signature).
        arg_types: Arg types paired positionally with the parameters of the parsed function.
        user_defined_rules: Extra sparse rules forwarded to the transformer.
        sparse_name (str): Name given to the generated function; falls back to the name of the parsed function.
        full_sparse_rules: Already-built rule table forwarded to the transformer.
        depth (int): Current expansion depth forwarded to the transformer.

    Returns:
        A tuple `(tree, changed, return_types)`. When any statement changed, `tree` is a new `ast.Module` holding a
        single function named `sparse_name` whose body starts with the helper definitions collected by the
        transformer; otherwise `tree` is the parsed module of `f` and `changed` is False.
    """
    is_cell = isinstance(f, nn.Cell)
    source_fn = f.construct if is_cell else f
    tree = ast.parse(textwrap.dedent(inspect.getsource(source_fn)))
    if is_cell:
        # remove self
        tree.body[0].args.args.pop(0)
    global_vars = source_fn.__globals__
    # pylint: disable=protected-access
    init_vars = f._cells if is_cell else {}

    functiondef = tree.body[0]
    param_names = [param.arg for param in functiondef.args.args]
    transformer = SparseTransformer(
        dict(zip(param_names, arg_types)), global_vars, init_vars, user_defined_rules, full_sparse_rules, depth)
    target_name = sparse_name or functiondef.name

    new_body = []
    any_changed = False
    for stmt in functiondef.body:
        new_body.append(transformer.transform(stmt))
        # has_changed() only reflects the statement that was just transformed
        if transformer.has_changed():
            any_changed = True
    return_types = transformer.return_types

    if not any_changed:
        return tree, False, return_types

    # definitions generated for nested callables are placed ahead of the transformed statements
    helper_defs = [entry[0] for entry in transformer.sparse_functiondef.values()]
    rewritten = ast.FunctionDef(
        target_name, functiondef.args, helper_defs + new_body, functiondef.decorator_list, functiondef.returns)
    return ast.Module([rewritten]), True, return_types
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class SparseTransformer(ast.NodeTransformer):
    """
    AST transformer that propagates arg types through statements and rewrites calls to their sparse counterparts.

    Every visited expression pushes the arg type(s) it evaluates to onto the innermost "frame" (a list kept on
    `self._frames`); statement-level and call-level visitors open a frame, visit their operands and pop the frame to
    read the collected types.
    """
    def __init__(self, type_map, global_vars, init_vars, user_defined_rules=None, full_sparse_rules=None, depth=0):
        """
        Init method.

        Args:
            type_map (dict): Maps variable names to their arg type; updated as assignments are visited.
            global_vars (dict): Names resolvable as globals by the code being transformed.
            init_vars (dict): Objects looked up for calls of the form `self.<name>(...)`.
            user_defined_rules (dict): Extra rules merged over the generic `sparse_rules`; only used when
                `full_sparse_rules` is not given.
            full_sparse_rules (dict): Rule table inherited from the caller; when truthy it is used as is.
            depth (int): Current expansion depth, compared against MAX_RECURSION_DEPTH.
        """
        super().__init__()
        self.type_map = type_map
        self.global_vars = global_vars
        self.init_vars = init_vars
        self.depth = depth
        self.return_types = (ArgType.NONSPARSE,)
        # maps function name and arg types to sparsified ast and return types, which are then inserted into module
        self.sparse_functiondef = {}
        # maps function name and arg types to return types for ast that do not change after sparsify
        self.origin_functiondef = {}

        # keeps track of arg_type for each operand on the call stack recursively
        self._frames = deque()
        self._changed = False
        # variables for which arg_types diverge with control flow are not supported, and are considered dead
        # after exiting the block
        self._dead_vars = {}
        # full_sparse_rules are inherited from caller cell and takes precedence over generic rules
        if full_sparse_rules:
            self.full_sparse_rules = full_sparse_rules
        else:
            self.full_sparse_rules = {}
            self.get_sparse_rules(user_defined_rules or {})

    @staticmethod
    def make_call(node, name="", args=None):
        """
        Returns a call node with given name and args, if provided.

        The callee and the positional args default to those of `node`; the keywords of `node` are always reused.
        """
        func = ast.Name(name, ast.Load()) if name else node.func
        if args is None:
            args = node.args
        return ast.Call(func, args, node.keywords)

    def get_sparse_rules(self, user_defined_rules):
        """
        Generates sparse rules for the transformer from generic sparse rules and user-defined sparse rules.

        A user-defined entry replaces the whole generic entry registered under the same key. The result is stored in
        `self.full_sparse_rules` as `{func: {tuple(input arg types): SparseFunc}}`.
        """
        for func, rules in {**sparse_rules, **user_defined_rules}.items():
            for rule in rules:
                sparse_func = get_sparse_func(rule)
                # sparse rules are accessed by the function object and input arg_types pair
                self.full_sparse_rules.setdefault(func, {})[tuple(sparse_func.inputs)] = sparse_func

    def transform(self, node):
        """Transforms a single node which represents a stmt in the ast, resetting frames and the changed flag."""
        self.clear_stack()
        self._changed = False
        return self.visit(node)

    def has_changed(self):
        """Whether the SparseTransformer has changed anything since the last call to `transform`."""
        return self._changed

    def add_frame(self):
        """Add a frame into deque."""
        self._frames.append([])

    def pop_frame(self):
        """Pop the innermost frame and return the arg types collected in it as a tuple."""
        return tuple(self._frames.pop())

    def push_onto_frame(self, t):
        """
        Push an arg_type into the innermost frame.

        Raises:
            ValueError: If no frame has been opened.
        """
        if not self._frames:
            raise ValueError("Current frame not initialized!")
        self._frames[-1].append(t)

    def push_all_onto_frame(self, t):
        """
        Push all arg_types of the iterable `t` into the innermost frame.

        Raises:
            ValueError: If no frame has been opened.
        """
        if not self._frames:
            raise ValueError("Current frame not initialized!")
        self._frames[-1].extend(t)

    def clear_stack(self):
        """Clear frame deque"""
        self._frames.clear()

    def make_sparse_func(self, func, node_type, inputs):
        """
        Returns SparseFunc by looking up sparse_rules.

        Lookup order: a rule keyed by `ArgType.ANY`, then a rule keyed by `inputs`, then an attribute of
        `mindspore.ops` named `<prefix of inputs[0]>_<func name>`. When nothing matches, a SparseFunc for the
        original name with non-sparse output is returned if all inputs are non-sparse, otherwise None.
        Sets the changed flag when the selected SparseFunc targets something other than `func`.
        """
        if node_type == ast.Call:
            if isinstance(func, nn.Cell):
                func_name = func.__class__.__name__.lower()
            else:
                func_name = getattr(func, "__name__", func)
        elif node_type == ast.BinOp:
            func_name = func
        rules = self.full_sparse_rules.get(func, {})

        if ArgType.ANY in rules:
            sparse_func = rules[ArgType.ANY]
        elif inputs in rules:
            sparse_func = rules[inputs]
        else:
            # attempts to find sparse op based on sparse prefix if sparse rules not found
            sparse_func_name = arg_type_to_prefix_map.get(inputs[0], "$") + "_" + func_name
            sparse_op = getattr(ops, sparse_func_name, None)
            if sparse_op is None:
                if any(input_type != ArgType.NONSPARSE for input_type in inputs):
                    return None
                outputs = (ArgType.NONSPARSE,)
            else:
                func_name = sparse_func_name
                _, outputs = get_inputs_outputs(sparse_op)
            sparse_func = SparseFunc(func_name, inputs, outputs)

        if sparse_func.fn != func:
            self._changed = True
        return sparse_func

    def get_sparse_node(self, node, args, func, arg_types):
        """
        Retrieves target from sparse rules if matches, otherwise sparsify the node by recursively expanding `func`
        until maximum recursion depth is reached. Functions in mindspore.ops are not expanded.
        If no matching sparse rule is found, an error is raised.

        Returns:
            The call node to use in place of `node`; the resulting arg types are pushed onto the current frame.

        Raises:
            ValueError: If `func` lives under mindspore.ops and has no sparse rule, or if `func` is a nested cell
                that takes init arguments.
            RuntimeError: If expanding `func` would exceed MAX_RECURSION_DEPTH.
        """
        sparse_func = self.make_sparse_func(func, type(node), arg_types)
        if sparse_func is not None:
            # NOTE(review): the flag may already be set by an earlier call in the same statement - confirm intended.
            if self._changed:
                if sparse_func.fn in self.global_vars:
                    func_node = ast.Name(sparse_func.fn, ast.Load())
                else:
                    # targets that are not globals of the transformed code are addressed as `ops.<fn>`
                    func_node = ast.Attribute(ast.Name("ops", ast.Load()), sparse_func.fn, ast.Load())
                node = ast.Call(func_node, args, node.keywords)
            self.push_all_onto_frame(sparse_func.outputs)
            return node

        if func.__module__.startswith(OPS_MODULE):
            raise ValueError(f"Sparse rules not registered for {func}!")

        if isinstance(func, nn.Cell):
            class_name = func.__class__.__name__
            func_name = class_name.lower()
            init_args = inspect.getfullargspec(func).args
            if len(init_args) != 1:
                raise ValueError(f"Nested cell {class_name} with arguments for init not supported!")
        else:
            func_name = func.__name__
        type_tag = "_".join(arg_type_to_prefix_map.get(t, "default") for t in arg_types)
        sparse_func_name = f"sparse_{type_tag}_{func_name}"

        cache_key = (func_name, arg_types)
        if cache_key in self.sparse_functiondef:
            # already expanded and sparsified for these arg types
            self._changed = True
            self.push_all_onto_frame(self.sparse_functiondef[cache_key][1])
            return SparseTransformer.make_call(node, sparse_func_name, args)
        if cache_key in self.origin_functiondef:
            # already expanded for these arg types and found unchanged
            self.push_all_onto_frame(self.origin_functiondef[cache_key])
            return node
        if self.depth == MAX_RECURSION_DEPTH:
            raise RuntimeError(f"Maximum recursion depth {MAX_RECURSION_DEPTH} for sparsify reached at {func}!")

        functiondef, changed, return_types = sparsify_helper(
            func, arg_types, sparse_name=sparse_func_name, full_sparse_rules=self.full_sparse_rules,
            depth=self.depth + 1)
        self.push_all_onto_frame(return_types)
        if changed:
            self._changed = True
            self.sparse_functiondef[cache_key] = (functiondef, return_types)
            return SparseTransformer.make_call(node, sparse_func_name, args)
        self.origin_functiondef[cache_key] = return_types
        return SparseTransformer.make_call(node, args=args)

    def map_type_to_target(self, node_target, value_types):
        """
        Records arg_type for each target.

        A tuple/list target must consist of plain names and match `value_types` in length. A single name receives
        the only value type, or the whole tuple when there are several.

        Raises:
            ValueError: On a size mismatch, a non-name element, or an unsupported target node.
        """
        if isinstance(node_target, (ast.Tuple, ast.List)):
            targets = node_target.elts
            if len(targets) != len(value_types):
                raise ValueError(f"Target {astunparse.unparse(node_target)} size and value size not match for "
                                 f"ast.Assign {len(targets)} != {len(value_types)}")
            # validate every element before recording anything, so a bad target leaves type_map untouched
            for target in targets:
                if not isinstance(target, ast.Name):
                    raise ValueError(f"Each target {ast.dump(target)} for ast.Assign should be ast.Name!")
            for target, value_type in zip(targets, value_types):
                self.type_map[target.id] = value_type
        elif isinstance(node_target, ast.Name):
            self.type_map[node_target.id] = value_types[0] if len(value_types) == 1 else value_types
        else:
            raise ValueError(f"Targets for ast.Assign not supported for {type(node_target)}!")

    def visit_method(self, node):
        """
        Visits each node based on node class.

        Raises:
            ValueError: If this class has no `visit_<NodeClass>` method.
        """
        visitor = getattr(self, "visit_" + node.__class__.__name__, None)
        if visitor is None:
            raise ValueError(f"{type(node)} is not supported in SparseTransformer!")
        return visitor(node)

    def visit(self, node):
        """
        Visitor interface for all nodes.

        Nodes without fields are returned untouched; a fixed set of node classes is routed to the generic visitors
        below and everything else must have a dedicated `visit_<NodeClass>` method.
        """
        if not node._fields:
            return node
        if isinstance(node, (ast.AugAssign, ast.Expr)):
            return self.visit_generic_stmt(node)
        if isinstance(node, (ast.BoolOp, ast.Compare, ast.Subscript)):
            # node always evaluates to non-sparse values
            return self.visit_generic_expr(node)
        if isinstance(node, (ast.Tuple, ast.List, ast.UnaryOp)):
            # node contains multiple expressions but is not composable
            return self.visit_composite_generic_expr(node)
        # NOTE(review): ast.Num / ast.Str / ast.Index are deprecated node classes in newer Python releases.
        if isinstance(node, (ast.Attribute, ast.Num, ast.Str)):
            return self.visit_scalar_expr(node)
        if isinstance(node, (ast.Index, ast.Slice)):
            # node forms only a part of an expression and does not exist as standalone expression
            return self.visit_partial_expr(node)
        return self.visit_method(node)

    def visit_generic_stmt(self, node):
        """Visitor for generic statement: children are visited in a scratch frame that is then discarded."""
        self.add_frame()
        node = self.generic_visit(node)
        self.pop_frame()
        return node

    def visit_scalar_expr(self, node):
        """Visitor for scalar expression: pushes a single non-sparse type without visiting children."""
        self.push_onto_frame(ArgType.NONSPARSE)
        return node

    def visit_generic_expr(self, node):
        """Visitor for generic expression: children are visited in a scratch frame, the result is non-sparse."""
        self.add_frame()
        node = self.generic_visit(node)
        self.pop_frame()
        self.push_onto_frame(ArgType.NONSPARSE)
        return node

    def visit_composite_generic_expr(self, node):
        """Visitor for composite generic expression: each child pushes its own type(s) onto the current frame."""
        return self.generic_visit(node)

    def visit_partial_expr(self, node):
        """Visitor for a part of an expression: returned as is, nothing is pushed."""
        return node

    def visit_Assign(self, node):  # pylint: disable=invalid-name
        """Visitor for ast.Assign: records the arg types of the value for every target."""
        self.add_frame()
        value = self.visit(node.value)
        value_types = self.pop_frame()
        for node_target in node.targets:
            self.map_type_to_target(node_target, value_types)
        return ast.Assign(node.targets, value)

    def visit_BinOp(self, node):  # pylint: disable=invalid-name
        """
        Visitor for ast.Binop.

        The node itself is returned as visited; only the output types of the matching sparse rule are pushed.

        Raises:
            ValueError: If the operands do not yield exactly two arg types, or no sparse rule covers them.
        """
        self.add_frame()
        node = self.generic_visit(node)
        arg_types = self.pop_frame()
        if len(arg_types) != 2:
            raise ValueError(f"Binary op {astunparse.unparse(node)} values for arg_type len({arg_types}) != 2")
        func = get_binop_name(node.op)
        if func:
            sparse_func = self.make_sparse_func(func, type(node), arg_types)
            if sparse_func is None:
                raise ValueError(f"Sparse rules not defined for {arg_types[0]} {func} {arg_types[1]}!")
            outputs = sparse_func.outputs
        else:
            outputs = (ArgType.NONSPARSE,)
        self.push_all_onto_frame(outputs)
        return node

    def visit_Call(self, node):  # pylint: disable=invalid-name
        """
        Visitor for ast.Call.

        Only positional arguments are visited. The callee is resolved, in order, as a builtin op or global name,
        an attribute of a global namespace, an entry of `init_vars` (scope `self`), or a method of a typed variable.

        Raises:
            RuntimeError: If no function name can be extracted from the call.
            ValueError: If the callee or its scope cannot be resolved.
        """
        self.add_frame()
        args = [self.visit(arg) for arg in node.args]
        arg_types = self.pop_frame()

        # NOTE(review): a call without positional args also takes this branch, whatever its receiver - confirm.
        if all(t == ArgType.NONSPARSE for t in arg_types):
            # if none of the arguments is sparse, do nothing
            self.push_onto_frame(ArgType.NONSPARSE)
            return node

        # pylint: disable=protected-access
        func_name = AssignParser._get_func_name(node)
        if func_name is None or func_name == "":
            raise RuntimeError(f"Function not exist for {ast.dump(node)}!")
        # pylint: disable=protected-access
        func_scope = AssignParser._get_func_scope(node)

        if not func_scope:
            if func_name in builtin_ops:
                self.push_onto_frame(ArgType.NONSPARSE)
                return node
            if func_name in self.global_vars:
                # external function with sparse arguments are inlined and cached
                return self.get_sparse_node(node, args, self.global_vars[func_name], arg_types)
            raise ValueError(f"Call to undefined {func_name}!")

        if func_scope in self.global_vars:
            namespace = self.global_vars[func_scope]
            func = getattr(namespace, func_name, None)
            if func is None:
                raise ValueError(f"{func_name} not defined in {namespace}!")
            return self.get_sparse_node(node, args, func, arg_types)

        if func_scope == "self":
            func = self.init_vars.get(func_name, None)
            if func is None:
                raise ValueError(f"{func_name} not defined in Cell.__init__!")
            return self.get_sparse_node(node, args, func, arg_types)

        func_scope_type = self.type_map.get(func_scope, None)
        if func_scope_type is None:
            raise ValueError(f"Undefined var {func_scope}!")
        # tensor methods
        if func_scope_type == ArgType.NONSPARSE:
            outputs = (ArgType.NONSPARSE,)
        else:
            outputs = get_sparse_method_outputs(func_name, func_scope_type)
        self.push_all_onto_frame(outputs)
        return node

    def visit_Name(self, node):  # pylint: disable=invalid-name
        """
        Visitor for ast.Name: pushes the arg type(s) recorded for the variable.

        Raises:
            ValueError: If the name is a dead variable or is not known at all.
        """
        # NOTE(review): a dead variable whose name is also a global resolves to the global branch - confirm.
        if node.id in self.type_map:
            tensor_type = self.type_map[node.id]
        elif node.id in self.global_vars:
            logger.warning(f"Global variable {node.id} treated as nonsparse value by default.")
            tensor_type = ArgType.NONSPARSE
        elif node.id in self._dead_vars:
            raise ValueError(f"Divergent arg_types {self._dead_vars.get(node.id)} for {node.id} are currently not "
                             f"supported in control flow and the variable is considered dead upon leaving "
                             f"the block")
        else:
            raise ValueError(f"Undefined variable {node.id}!")

        if isinstance(tensor_type, tuple):
            self.push_all_onto_frame(tensor_type)
        else:
            self.push_onto_frame(tensor_type)
        return node

    def visit_Return(self, node):  # pylint: disable=invalid-name
        """Visitor for ast.Return: stores the arg types of the returned expression in `self.return_types`."""
        self.add_frame()
        node = self.generic_visit(node)
        self.return_types = self.pop_frame()
        return node

    def visit_While(self, node):  # pylint: disable=invalid-name
        """
        Visitor for ast.While.
        Variables for which arg_types diverge with control flow are not supported, and as a fallback routine,
        unsupported variables are treated as out-of-scope after leaving the control flow body.
        """
        self.add_frame()
        test = self.visit(node.test)
        self.pop_frame()

        orig_type_map = self.type_map.copy()
        body = [self.visit(stmt) for stmt in node.body]
        for var, t in self.type_map.items():
            if var not in orig_type_map:
                # new variables in while body are considered active after the leaving the block
                orig_type_map[var] = t
            elif orig_type_map[var] != t:
                # variables for which arg_types diverge are considered dead after leaving the block
                self._dead_vars[var] = (t, orig_type_map.pop(var))
        self.type_map = orig_type_map

        orelse = [self.visit(stmt) for stmt in node.orelse]
        return ast.While(test, body, orelse)
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""sparsify implementation"""
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
from mindspore import ops
|
|
19
|
+
from mindspore.rewrite import SymbolTree, ScopedValue
|
|
20
|
+
from mindspore.rewrite.ast_helpers import AstModifier
|
|
21
|
+
from mindspore.rewrite.sparsify.sparse_transformer import SparseTransformer
|
|
22
|
+
from mindspore.rewrite.sparsify.utils import SparseFunc, ArgType
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Namespace dict of ``mindspore.ops``; get_user_defined_rules uses it to check, by identity, that a callable
# which is not a global of the network is exposed there under its own ``__name__``.
op_vars = vars(ops)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_user_defined_rules(sparse_rules, global_vars, tree):
    """
    Register user-defined sparse rules.

    Every value of `sparse_rules` is normalised to a list of rules. Callables used as rule targets must either be
    globals of the network, in which case an assignment to `self.<name>` is appended to the init function of
    `tree`, or be exposed by `mindspore.ops` under their own name.

    Returns:
        dict, maps each source of `sparse_rules` to the list of its rules.

    Raises:
        ValueError: If a callable target is found neither in `global_vars` nor in `mindspore.ops`.
    """
    collected = {}

    def register_callable(fn):
        func_name = fn.__name__
        if global_vars.get(func_name) is fn:
            init_targets = [ScopedValue.create_naming_value(func_name, "self")]
            AstModifier.append_global_vars_expr_to_init(tree.get_init_func_ast(), init_targets, func_name)
            return
        if op_vars.get(func_name) is not fn:
            raise ValueError(f"{fn} not found in globals or mindspore.ops!")

    for source, targets in sparse_rules.items():
        # a SparseFunc is treated as a single rule even though it passes the tuple/list check
        is_rule_sequence = isinstance(targets, (tuple, list)) and not isinstance(targets, SparseFunc)
        rule_list = list(targets) if is_rule_sequence else [targets]
        for sparse_func in rule_list:
            if isinstance(sparse_func, SparseFunc) and callable(sparse_func.fn):
                register_callable(sparse_func.fn)
            elif callable(sparse_func):
                register_callable(sparse_func)
            collected.setdefault(source, []).append(sparse_func)

    return collected
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def sparsify_tree(tree, arg_types, sparse_rules, f):
    """
    Sparsify SymbolTree object.

    Builds the name-to-arg-type map for the parameters of the root function (skipping `self`), transforms every
    statement of the root body in place and appends the helper definitions generated for nested callables to the
    module ast of `tree`.

    Args:
        tree: Handler exposing `get_ast_root`, `get_module_ast` and `get_init_func_ast`.
        arg_types (tuple | dict): Arg types per parameter; a tuple must cover every parameter, a dict is keyed
            either by parameter position or by parameter name and defaults the rest to non-sparse.
        sparse_rules (dict): User-defined sparse rules.
        f: Cell whose `construct` globals and sub-cells are used to resolve calls.

    Raises:
        ValueError: If `arg_types` has the wrong length, mixed key types, or an unsupported type.
    """
    global_vars = f.construct.__globals__
    user_defined_rules = get_user_defined_rules(sparse_rules, global_vars, tree)

    # skip self
    args = [param.arg for param in tree.get_ast_root().args.args[1:]]
    if isinstance(arg_types, tuple):
        if len(args) != len(arg_types):
            raise ValueError(f"arg_types should have the same length as function parameters, but "
                             f"{len(arg_types)} != {len(args)}!")
        type_map = dict(zip(args, arg_types))
    elif isinstance(arg_types, dict):
        keys = list(arg_types.keys())
        if all(isinstance(key, int) for key in keys):
            type_map = {name: arg_types.get(pos, ArgType.NONSPARSE) for pos, name in enumerate(args)}
        elif all(isinstance(key, str) for key in keys):
            type_map = {name: arg_types.get(name, ArgType.NONSPARSE) for name in args}
        else:
            raise ValueError(f"Keys for arg_types {keys} should be all ints or all strings!")
    else:
        raise ValueError(f"Unsupported type for arg_types {type(arg_types)}!")

    # pylint: disable=protected-access
    sparse_transformer = SparseTransformer(type_map, global_vars, f._cells, user_defined_rules)
    for index, stmt in enumerate(tree.get_ast_root().body):
        transformed = sparse_transformer.transform(stmt)
        # only statements the transformer reports as changed are written back
        if sparse_transformer.has_changed():
            tree.get_ast_root().body[index] = transformed
    for entry in sparse_transformer.sparse_functiondef.values():
        tree.get_module_ast().body.append(entry[0])
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def sparsify(f, arg_types, sparse_rules=None):
    """
    Sparsify a Cell object by inferring the appropriate sparse function calls to replace the original function calls by
    propagating sparse properties provided in `arg_types`.

    Args:
        f (Cell): Cell object to be sparsified.
        arg_types (Tuple[ArgType] | Dict[int, ArgType] | Dict[str, ArgType]): The type of argument (sparse csr,
            sparse coo, non-sparse etc.) expected by `f`. If `arg_type` is a tuple, its length should be the same as
            the number of arguments for `f`; if `arg_type` is a dictionary, each key represents an index into the
            arguments or the name of an argument (all keys must be of the same kind), and arguments not referenced
            by the dictionary are considered to be non-sparse.
        sparse_rules (Dict[str, SparseFunc], Optional): Additional sparse rules.

    Returns:
        The network obtained from the rewritten SymbolTree of `f`.
    """
    flag = "STREE_PYTHON_FALLBACK"
    previous = os.environ.get(flag)
    os.environ[flag] = "1"
    try:
        tree = SymbolTree.create(f)
        handler = tree.get_handler()
        sparsify_tree(handler, arg_types, sparse_rules or {}, f)
        # the network is built while the flag is still visible through os.environ, as it was before this cleanup
        # was made effective
        return tree.get_network()
    finally:
        # os.unsetenv() does not update os.environ, so the flag has to be removed from the mapping itself;
        # a value that was set before the call is put back, also when an exception is raised
        if previous is None:
            os.environ.pop(flag, None)
        else:
            os.environ[flag] = previous
|