mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,12 +13,12 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
import ast
|
|
18
|
-
|
|
17
|
+
from mindspore import log as logger
|
|
19
18
|
from ..parser_register import ParserRegister, reg_parser
|
|
20
19
|
from ..parser import Parser
|
|
21
20
|
from ..symbol_tree import SymbolTree
|
|
21
|
+
from ..api.node_type import NodeType
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class FunctionDefParser(Parser):
|
|
@@ -28,6 +28,24 @@ class FunctionDefParser(Parser):
|
|
|
28
28
|
"""Parse target type"""
|
|
29
29
|
return ast.FunctionDef
|
|
30
30
|
|
|
31
|
+
def remove_dead_code(self, stree: SymbolTree):
|
|
32
|
+
"""Remove dead codes"""
|
|
33
|
+
# Find out return node position
|
|
34
|
+
return_idx = -1
|
|
35
|
+
for idx, node in enumerate(stree.nodes()):
|
|
36
|
+
if node.get_node_type() == NodeType.Output:
|
|
37
|
+
return_idx = idx
|
|
38
|
+
break
|
|
39
|
+
if return_idx == -1:
|
|
40
|
+
return
|
|
41
|
+
# Remove nodes after return node.
|
|
42
|
+
# Reverse traversal to ensure that nodes are orphaned and can be deleted.
|
|
43
|
+
for idx, node in reversed(list(enumerate(stree.nodes()))):
|
|
44
|
+
if idx <= return_idx:
|
|
45
|
+
break
|
|
46
|
+
logger.info(f"Remove dead code node:{node.get_name()}")
|
|
47
|
+
stree.erase_node(node)
|
|
48
|
+
|
|
31
49
|
def process(self, stree: SymbolTree, node: ast.FunctionDef):
|
|
32
50
|
"""Parse bodies of ast.FunctionDef which is construct function to nodes of SymbolTree."""
|
|
33
51
|
stree.set_ast_root(node)
|
|
@@ -45,13 +63,11 @@ class FunctionDefParser(Parser):
|
|
|
45
63
|
else:
|
|
46
64
|
parser.process(stree, body)
|
|
47
65
|
|
|
48
|
-
for body in node.body:
|
|
49
|
-
if isinstance(body, ast.For):
|
|
50
|
-
node.body.remove(body)
|
|
51
66
|
if hasattr(node, "decorator_list"):
|
|
52
67
|
stree.try_append_python_node(node, node.decorator_list)
|
|
53
68
|
if hasattr(node, "returns"):
|
|
54
69
|
stree.try_append_python_node(node, node.returns)
|
|
70
|
+
self.remove_dead_code(stree)
|
|
55
71
|
|
|
56
72
|
|
|
57
73
|
g_functiondef_parser = reg_parser(FunctionDefParser())
|
|
@@ -13,13 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.If in construct function to node of SymbolTree."""
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
import ast
|
|
18
18
|
import astunparse
|
|
19
19
|
|
|
20
20
|
from ..symbol_tree import SymbolTree
|
|
21
21
|
from ..parser import Parser
|
|
22
|
-
from ..parser_register import
|
|
22
|
+
from ..parser_register import reg_parser
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class IfParser(Parser):
|
|
@@ -44,37 +44,20 @@ class IfParser(Parser):
|
|
|
44
44
|
test_code = astunparse.unparse(node.test)
|
|
45
45
|
test_code = test_code.replace("self", "stree.get_origin_network()")
|
|
46
46
|
bodies = None
|
|
47
|
-
src_bodies = None
|
|
48
|
-
dst_bodies = None
|
|
49
|
-
test_value = True
|
|
50
47
|
try:
|
|
51
48
|
test_value = eval(test_code)
|
|
52
|
-
except NameError:
|
|
49
|
+
except (NameError, TypeError):
|
|
53
50
|
stree.try_append_python_node(node, node)
|
|
54
51
|
return
|
|
55
52
|
|
|
56
53
|
bodies = node.body if test_value else node.orelse
|
|
54
|
+
index = stree.get_ast_root().body.index(node) + 1
|
|
55
|
+
info_node = ast.Name(id=f"# If node has been replaced by {bool(test_value)} branch.",
|
|
56
|
+
lineno=0, col_offset=0, ctx=ast.Load)
|
|
57
|
+
exp_node = ast.Expr(value=info_node, lineno=0, col_offset=0, ctx=ast.Load)
|
|
58
|
+
stree.get_ast_root().body.insert(index-1, exp_node)
|
|
57
59
|
for body in bodies:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
else:
|
|
62
|
-
parser.process(stree, body)
|
|
63
|
-
|
|
64
|
-
# hardcode for if, ME need both branch of ast.If has same output
|
|
65
|
-
src_bodies = node.body if test_value else node.orelse
|
|
66
|
-
dst_bodies = node.orelse if test_value else node.body
|
|
67
|
-
dst_bodies.clear()
|
|
68
|
-
if src_bodies:
|
|
69
|
-
for ast_node in src_bodies:
|
|
70
|
-
if not isinstance(ast_node, ast.Assign):
|
|
71
|
-
continue
|
|
72
|
-
targets = ast_node.targets
|
|
73
|
-
for target in targets:
|
|
74
|
-
dst_bodies.append(ast.Assign(targets=[target], value=ast.Constant(value=0, kind=None,
|
|
75
|
-
ctx=ast.Load())))
|
|
76
|
-
else:
|
|
77
|
-
dst_bodies.append(ast.Pass())
|
|
78
|
-
|
|
79
|
-
|
|
60
|
+
stree.get_ast_root().body.insert(index, body)
|
|
61
|
+
index += 1
|
|
62
|
+
stree.get_ast_root().body.remove(node)
|
|
80
63
|
g_if_parser = reg_parser(IfParser())
|
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.Module to SymbolTrees."""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
from typing import Any
|
|
18
17
|
import os
|
|
19
18
|
import ast
|
|
@@ -26,6 +25,7 @@ from ..symbol_tree import SymbolTree
|
|
|
26
25
|
from ..parser import Parser
|
|
27
26
|
from ..parser_register import ParserRegister, reg_parser
|
|
28
27
|
from ..ast_helpers import AstFinder
|
|
28
|
+
from ..common import error_str
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
class ModuleParser(Parser):
|
|
@@ -41,9 +41,9 @@ class ModuleParser(Parser):
|
|
|
41
41
|
visitor = AstFinder(ast_node)
|
|
42
42
|
classes = visitor.find_all(ast.ClassDef)
|
|
43
43
|
if not classes:
|
|
44
|
-
raise RuntimeError("
|
|
44
|
+
raise RuntimeError(error_str("no class in module.", father_node=ast_node))
|
|
45
45
|
if len(classes) > 1:
|
|
46
|
-
raise RuntimeError("
|
|
46
|
+
raise RuntimeError(error_str("multi-class in module is not supported now", father_node=ast_node))
|
|
47
47
|
return classes[0]
|
|
48
48
|
|
|
49
49
|
@staticmethod
|
|
@@ -53,6 +53,7 @@ class ModuleParser(Parser):
|
|
|
53
53
|
|
|
54
54
|
class GetImportNode(ast.NodeVisitor):
|
|
55
55
|
"""Find all import nodes from input ast node."""
|
|
56
|
+
|
|
56
57
|
def visit_Import(self, node: ast.Import) -> Any:
|
|
57
58
|
"""Iterate over all nodes and save ast.Import nodes."""
|
|
58
59
|
import_nodes.append(copy.deepcopy(node))
|
|
@@ -83,13 +84,14 @@ class ModuleParser(Parser):
|
|
|
83
84
|
level=0))
|
|
84
85
|
origin_net_source_code_file = inspect.getfile(type(origin_net))
|
|
85
86
|
if not os.path.exists(origin_net_source_code_file):
|
|
86
|
-
raise RuntimeError("File ", origin_net_source_code_file,
|
|
87
|
+
raise RuntimeError("For MindSpore Rewrite, in module parser, File ", origin_net_source_code_file,
|
|
88
|
+
" not exist")
|
|
87
89
|
try:
|
|
88
90
|
with open(origin_net_source_code_file, "r") as f:
|
|
89
91
|
source_code = f.read()
|
|
90
92
|
import_nodes = ModuleParser.get_import_node(ast.parse(source_code))
|
|
91
93
|
except RuntimeError:
|
|
92
|
-
raise RuntimeError("get import nodes error")
|
|
94
|
+
raise RuntimeError("For MindSpore Rewrite, in module parser, get import nodes error")
|
|
93
95
|
if import_nodes:
|
|
94
96
|
for import_index, import_node in enumerate(import_nodes):
|
|
95
97
|
module.body.insert(import_index + 3, import_node)
|
|
@@ -105,7 +107,8 @@ class ModuleParser(Parser):
|
|
|
105
107
|
parser: Parser = ParserRegister.instance().get_parser(ast.ClassDef)
|
|
106
108
|
parser.process(stree, body)
|
|
107
109
|
else:
|
|
108
|
-
logger.info(f"
|
|
110
|
+
logger.info(f"For MindSpore Rewrite, in module parser, Ignoring unsupported "
|
|
111
|
+
f"node({astunparse.unparse(body)}) in ast.Module.")
|
|
109
112
|
|
|
110
113
|
|
|
111
114
|
g_module_parser = reg_parser(ModuleParser())
|
|
@@ -13,13 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.Return output-node of SymbolTree."""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
import ast
|
|
18
17
|
|
|
19
18
|
from ..symbol_tree import SymbolTree
|
|
20
19
|
from ..node import Node
|
|
21
20
|
from ..parser import Parser
|
|
22
21
|
from ..parser_register import reg_parser
|
|
22
|
+
from ..common import error_str
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class ReturnParser(Parser):
|
|
@@ -33,7 +33,8 @@ class ReturnParser(Parser):
|
|
|
33
33
|
"""Parse ast.Return to output-node of SymbolTree."""
|
|
34
34
|
return_value = node.value
|
|
35
35
|
if not isinstance(return_value, ast.Name):
|
|
36
|
-
raise RuntimeError("
|
|
36
|
+
raise RuntimeError(error_str(f"only support ast.Name as return value, but got ast type "
|
|
37
|
+
f"'{type(return_value).__name__}'", father_node=node, child_node=return_value))
|
|
37
38
|
node_return = Node.create_output_node(node, [return_value.id])
|
|
38
39
|
stree.append_origin_field(node_return)
|
|
39
40
|
|
|
File without changes
|
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Sparsify transformer"""
|
|
16
|
+
import ast
|
|
17
|
+
import inspect
|
|
18
|
+
import textwrap
|
|
19
|
+
from collections import deque
|
|
20
|
+
import astunparse
|
|
21
|
+
|
|
22
|
+
from mindspore import ops, nn
|
|
23
|
+
from mindspore import log as logger
|
|
24
|
+
from mindspore.rewrite.parsers.assign_parser import AssignParser
|
|
25
|
+
from mindspore.rewrite.sparsify.utils import ArgType, SparseFunc, sparse_rules, get_sparse_func, builtin_ops, \
|
|
26
|
+
get_binop_name, get_sparse_method_outputs, arg_type_to_prefix_map, get_inputs_outputs
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
OPS_MODULE = "mindspore.ops."
|
|
30
|
+
MAX_RECURSION_DEPTH = 10
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def sparsify_helper(f, arg_types, user_defined_rules=None, sparse_name="", full_sparse_rules=None, depth=0):
|
|
34
|
+
"""Calls sparse_transformer from raw function."""
|
|
35
|
+
if isinstance(f, nn.Cell):
|
|
36
|
+
tree = ast.parse(textwrap.dedent(inspect.getsource(f.construct)))
|
|
37
|
+
# remove self
|
|
38
|
+
tree.body[0].args.args.pop(0)
|
|
39
|
+
global_vars = f.construct.__globals__
|
|
40
|
+
# pylint: disable=protected-access
|
|
41
|
+
init_vars = f._cells
|
|
42
|
+
else:
|
|
43
|
+
tree = ast.parse(textwrap.dedent(inspect.getsource(f)))
|
|
44
|
+
global_vars = f.__globals__
|
|
45
|
+
init_vars = {}
|
|
46
|
+
functiondef = tree.body[0]
|
|
47
|
+
args = [arg.arg for arg in functiondef.args.args]
|
|
48
|
+
type_map = dict(zip(args, arg_types))
|
|
49
|
+
|
|
50
|
+
sparse_transformer = SparseTransformer(
|
|
51
|
+
type_map, global_vars, init_vars, user_defined_rules, full_sparse_rules, depth)
|
|
52
|
+
sparse_tree = []
|
|
53
|
+
if not sparse_name:
|
|
54
|
+
sparse_name = functiondef.name
|
|
55
|
+
changed = False
|
|
56
|
+
for body in functiondef.body:
|
|
57
|
+
sparse_body = sparse_transformer.transform(body)
|
|
58
|
+
changed |= sparse_transformer.has_changed()
|
|
59
|
+
sparse_tree.append(sparse_body)
|
|
60
|
+
return_types = sparse_transformer.return_types
|
|
61
|
+
|
|
62
|
+
if changed:
|
|
63
|
+
sparse_tree = list(x[0] for x in sparse_transformer.sparse_functiondef.values()) + sparse_tree
|
|
64
|
+
ast_module = ast.Module([ast.FunctionDef(
|
|
65
|
+
sparse_name, functiondef.args, sparse_tree, functiondef.decorator_list, functiondef.returns)])
|
|
66
|
+
return ast_module, True, return_types
|
|
67
|
+
return tree, False, return_types
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class SparseTransformer(ast.NodeTransformer):
|
|
71
|
+
"""Transformer class for sparsify."""
|
|
72
|
+
def __init__(self, type_map, global_vars, init_vars, user_defined_rules=None, full_sparse_rules=None, depth=0):
    """
    Init method.

    Args:
        type_map (dict): maps variable names to their arg_types; assignments are recorded into it while visiting.
        global_vars (dict): globals of the function being transformed, used to resolve callees.
        init_vars (dict): sub-cells of the cell being transformed, used to resolve `self.<name>(...)` calls.
        user_defined_rules (dict): user-defined sparse rules, keyed like the generic `sparse_rules`; for a given
            function they replace the generic rule list.
        full_sparse_rules (dict): resolved sparse rules inherited from the caller, keyed by function and then by
            input arg_types. They take precedence over the rules generated here and are never modified.
        depth (int): current recursion depth of sparsify.
    """
    super().__init__()
    self.type_map = type_map
    self.global_vars = global_vars
    self.init_vars = init_vars
    self.depth = depth
    self.return_types = (ArgType.NONSPARSE,)
    # maps function name and arg types to sparsified ast and return types, which are then inserted into module
    self.sparse_functiondef = {}
    # maps function name and arg types to return types for ast that do not change after sparsify
    self.origin_functiondef = {}

    # keeps track of arg_type for each operand on the call stack recursively
    self._frames = deque()
    self._changed = False
    # variables for which arg_types diverge with control flow are not supported, and are considered dead
    # after exiting the block
    self._dead_vars = {}
    # generic rules (overridden per function by user-defined ones) form the base rule table
    self.full_sparse_rules = {}
    self.get_sparse_rules(user_defined_rules or {})
    # full_sparse_rules are inherited from caller cell and take precedence over generic rules. They are merged
    # into this transformer's own table rather than updated in place, so regenerating the generic rules can
    # neither override them nor leak into the caller's table.
    for func, inherited_rules in (full_sparse_rules or {}).items():
        self.full_sparse_rules.setdefault(func, {}).update(inherited_rules)
|
|
98
|
+
|
|
99
|
+
@staticmethod
def make_call(node, name="", args=None):
    """
    Builds a new ast.Call derived from `node`.

    The callee is a plain ast.Name `name` when one is given, otherwise `node.func` is reused; the positional
    arguments are `args` when given, otherwise `node.args`. Keywords always come from `node`.
    """
    callee = ast.Name(name, ast.Load()) if name else node.func
    positional = node.args if args is None else args
    return ast.Call(callee, positional, node.keywords)
|
|
109
|
+
|
|
110
|
+
def get_sparse_rules(self, user_defined_rules):
    """
    Generates sparse rules for the transformer from generic sparse rules and user-defined sparse rules.

    A user-defined rule list replaces the generic list of the same function. Every rule is resolved through
    get_sparse_func and stored as `self.full_sparse_rules[func][tuple(inputs)]`, overwriting existing entries.
    """
    combined = dict(sparse_rules)
    combined.update(user_defined_rules)
    for func, rule_list in combined.items():
        for rule in rule_list:
            resolved = get_sparse_func(rule)
            # sparse rules are accessed by the function object and input arg_types pair
            self.full_sparse_rules.setdefault(func, {})[tuple(resolved.inputs)] = resolved
|
|
119
|
+
|
|
120
|
+
def transform(self, node):
    """Transforms one statement node, starting from an empty frame stack and a cleared change flag."""
    self.clear_stack()
    self._changed = False
    return self.visit(node)
|
|
126
|
+
|
|
127
|
+
def has_changed(self):
    """Returns the change flag, which transform() resets before visiting each statement."""
    changed = self._changed
    return changed
|
|
130
|
+
|
|
131
|
+
def add_frame(self):
    """Opens a new, empty frame on top of the frame deque."""
    new_frame = []
    self._frames.append(new_frame)
|
|
134
|
+
|
|
135
|
+
def pop_frame(self):
    """Removes the innermost frame and returns its arg_types as a tuple."""
    top = self._frames.pop()
    return tuple(top)
|
|
138
|
+
|
|
139
|
+
def push_onto_frame(self, t):
    """Appends one arg_type to the innermost frame; raises ValueError when no frame is open."""
    frames = self._frames
    if len(frames) == 0:
        raise ValueError("Current frame not initialized!")
    frames[-1].append(t)
|
|
144
|
+
|
|
145
|
+
def push_all_onto_frame(self, t):
    """Appends every arg_type of the iterable `t`, in order, to the innermost frame."""
    frames = self._frames
    if len(frames) == 0:
        raise ValueError("Current frame not initialized!")
    frames[-1].extend(t)
|
|
151
|
+
|
|
152
|
+
def clear_stack(self):
    """Discards every open frame."""
    frames = self._frames
    frames.clear()
|
|
155
|
+
|
|
156
|
+
def make_sparse_func(self, func, node_type, inputs):
    """
    Returns SparseFunc by looking up sparse_rules.

    Args:
        func: the resolved callee for an ast.Call node, or the name returned by get_binop_name for an ast.BinOp node.
        node_type: `type(node)` of the node being sparsified; ast.Call and ast.BinOp are handled.
        inputs (tuple): arg_types of the operands, in order.

    Returns:
        SparseFunc describing the function to use and its output arg_types, or None when neither a registered rule
        nor a prefixed mindspore.ops function matches and at least one input is not ArgType.NONSPARSE.
        Sets `self._changed` when the resolved function differs from `func`.
    """
    rules = {}
    if node_type == ast.Call:
        if isinstance(func, nn.Cell):
            # cells are identified by their lower-cased class name
            func_name = func.__class__.__name__.lower()
        else:
            # falls back to `func` itself when it has no __name__
            func_name = getattr(func, "__name__", func)
    elif node_type == ast.BinOp:
        func_name = func
    # NOTE(review): func_name stays unbound for any other node_type; the visible callers only pass
    # ast.Call and ast.BinOp
    # rules are keyed by `func` itself
    rules = self.full_sparse_rules.get(func, {})

    # a rule stored under the ArgType.ANY key is used regardless of `inputs`
    if ArgType.ANY in rules:
        sparse_func = rules[ArgType.ANY]
    elif inputs in rules:
        sparse_func = rules[inputs]
    else:
        # attempts to find sparse op based on sparse prefix if sparse rules not found;
        # the prefix is chosen from the first input only, and "$" presumably matches no ops attribute
        sparse_func_name = arg_type_to_prefix_map.get(inputs[0], "$") + "_" + func_name
        sparse_op = getattr(ops, sparse_func_name, None)
        if sparse_op is None:
            # without a sparse counterpart only all-dense inputs can keep the original function
            if any(input_type != ArgType.NONSPARSE for input_type in inputs):
                return None
            outputs = (ArgType.NONSPARSE,)
        else:
            func_name = sparse_func_name
            _, outputs = get_inputs_outputs(sparse_op)
        sparse_func = SparseFunc(func_name, inputs, outputs)

    # presumably SparseFunc.fn names the function to call; a mismatch with `func` marks a rewrite
    if sparse_func.fn != func:
        self._changed = True
    return sparse_func
|
|
188
|
+
|
|
189
|
+
def get_sparse_node(self, node, args, func, arg_types):
    """
    Retrieves target from sparse rules if matches, otherwise sparsify the node by recursively expanding `func`
    until maximum recursion depth is reached. Functions in mindspore.ops are not expanded.
    If no matching sparse rule is found, an error is raised.

    Args:
        node (ast.Call): the call being sparsified.
        args (list): the already visited positional arguments of `node`.
        func: the resolved callee (a python function or an nn.Cell).
        arg_types (tuple): arg_types of `args`.

    Returns:
        ast.Call, the (possibly rewritten) call. The arg_types it produces are pushed onto the current frame.

    Raises:
        ValueError: if `func` belongs to mindspore.ops and has no sparse rule, or is a nested cell whose inspected
            signature does not have exactly one argument.
        RuntimeError: if the maximum recursion depth is reached.
    """
    sparse_func = self.make_sparse_func(func, type(node), arg_types)
    if sparse_func is not None:
        if self._changed:
            # sparse functions visible in the caller's globals are referenced by name, others through `ops.`
            if sparse_func.fn in self.global_vars:
                func_node = ast.Name(sparse_func.fn, ast.Load())
            else:
                func_node = ast.Attribute(ast.Name("ops", ast.Load()), sparse_func.fn, ast.Load())
            node = ast.Call(func_node, args, node.keywords)
        self.push_all_onto_frame(sparse_func.outputs)
        return node

    # objects without a string __module__ are treated as non-ops functions instead of failing here
    func_module = getattr(func, "__module__", None) or ""
    if func_module.startswith(OPS_MODULE):
        raise ValueError(f"Sparse rules not registered for {func}!")

    if isinstance(func, nn.Cell):
        class_name = func.__class__.__name__
        func_name = class_name.lower()
        # NOTE(review): getfullargspec on a Cell instance inspects the signature of its __call__ rather than
        # __init__ — confirm this check does what the error message describes
        init_args = inspect.getfullargspec(func).args
        if len(init_args) != 1:
            raise ValueError(f"Nested cell {class_name} with arguments for init not supported!")
    else:
        func_name = func.__name__
    sparse_func_name = f"sparse_{'_'.join(arg_type_to_prefix_map.get(t, 'default') for t in arg_types)}_{func_name}"
    key = (func_name, arg_types)
    if key in self.sparse_functiondef:
        # already sparsified for these arg_types: call the cached sparse function
        self._changed = True
        # pylint: disable=get-dict-value-exception
        self.push_all_onto_frame(self.sparse_functiondef[key][1])
        return SparseTransformer.make_call(node, sparse_func_name, args)
    if key in self.origin_functiondef:
        # already known to stay unchanged for these arg_types
        # pylint: disable=get-dict-value-exception
        self.push_all_onto_frame(self.origin_functiondef[key])
        # rebuild with the visited args, like the uncached path below, so rewrites of nested calls are kept
        return SparseTransformer.make_call(node, args=args)
    if self.depth >= MAX_RECURSION_DEPTH:
        raise RuntimeError(f"Maximum recursion depth {MAX_RECURSION_DEPTH} for sparsify reached at {func}!")
    functiondef, changed, return_types = sparsify_helper(
        func, arg_types, sparse_name=sparse_func_name, full_sparse_rules=self.full_sparse_rules,
        depth=self.depth + 1)
    self.push_all_onto_frame(return_types)
    if changed:
        self._changed = True
        self.sparse_functiondef[key] = (functiondef, return_types)
        return SparseTransformer.make_call(node, sparse_func_name, args)
    self.origin_functiondef[key] = return_types
    return SparseTransformer.make_call(node, args=args)
|
|
241
|
+
|
|
242
|
+
def map_type_to_target(self, node_target, value_types):
    """
    Records the arg_type(s) of an assigned value for an assignment target.

    A single name receives the only arg_type, or the whole tuple when the value produced several. A tuple/list
    target is unpacked element-wise: its length must match `value_types` and every element must be a plain name.
    """
    if isinstance(node_target, ast.Name):
        self.type_map[node_target.id] = value_types[0] if len(value_types) == 1 else value_types
        return
    if not isinstance(node_target, (ast.Tuple, ast.List)):
        raise ValueError(f"Targets for ast.Assign not supported for {type(node_target)}!")
    elts = node_target.elts
    if len(elts) != len(value_types):
        raise ValueError(f"Target {astunparse.unparse(node_target)} size and value size not match for "
                         f"ast.Assign {len(elts)} != {len(value_types)}")
    names = []
    for elt in elts:
        if not isinstance(elt, ast.Name):
            raise ValueError(f"Each target {ast.dump(elt)} for ast.Assign should be ast.Name!")
        names.append(elt.id)
    # names are validated before any of them is recorded
    self.type_map.update(zip(names, value_types))
|
|
264
|
+
|
|
265
|
+
def visit_method(self, node):
    """Dispatches `node` to this transformer's `visit_<ClassName>` method; unsupported classes raise ValueError."""
    handler = getattr(self, f"visit_{type(node).__name__}", None)
    if handler is None:
        raise ValueError(f"{type(node)} is not supported in SparseTransformer!")
    return handler(node)
|
|
272
|
+
|
|
273
|
+
def visit(self, node):
    """
    Visitor interface for all nodes.

    Routes `node` to a category handler (generic statement, always non-sparse expression, composite expression,
    scalar, partial expression) and falls back to `visit_<ClassName>` through visit_method otherwise.
    """
    if not node._fields:
        # nodes without fields (e.g. ast.Load or operator nodes) have nothing to visit
        return node
    if isinstance(node, (ast.AugAssign, ast.Expr)):
        return self.visit_generic_stmt(node)
    if isinstance(node, (ast.BoolOp, ast.Compare, ast.Subscript)):
        # node always evaluates to non-sparse values
        return self.visit_generic_expr(node)
    if isinstance(node, (ast.Tuple, ast.List, ast.UnaryOp)):
        # node contains multiple expressions but is not composable
        return self.visit_composite_generic_expr(node)
    # Literals are scalars. Since Python 3.8 every literal (numbers, strings, bytes, None, True/False, ...) is an
    # ast.Constant; without this check None/True/False fell through to ast.NodeVisitor.visit_Constant, which records
    # no arg_type and misaligns the operand types of the enclosing expression. On Python 3.7 the legacy literal
    # classes are matched by name, which avoids touching the deprecated ast.Num/ast.Str aliases.
    if (isinstance(node, (ast.Attribute, ast.Constant))
            or type(node).__name__ in ("Num", "Str", "Bytes", "NameConstant", "Ellipsis")):
        return self.visit_scalar_expr(node)
    if isinstance(node, (ast.Index, ast.Slice)):
        # node forms only a part of an expression and does not exist as standalone expression
        return self.visit_partial_expr(node)
    return self.visit_method(node)
|
|
291
|
+
|
|
292
|
+
def visit_generic_stmt(self, node):
    """Visits the children of a statement inside a temporary frame whose collected arg_types are discarded."""
    self.add_frame()
    visited = self.generic_visit(node)
    self.pop_frame()
    return visited
|
|
298
|
+
|
|
299
|
+
def visit_scalar_expr(self, node):
    """Records a scalar-like expression as ArgType.NONSPARSE and returns it unchanged."""
    nonsparse = ArgType.NONSPARSE
    self.push_onto_frame(nonsparse)
    return node
|
|
303
|
+
|
|
304
|
+
def visit_generic_expr(self, node):
    """
    Visits an expression that is treated as non-sparse as a whole.

    Its children are visited in a temporary frame whose arg_types are dropped, then a single ArgType.NONSPARSE is
    recorded for the expression.
    """
    self.add_frame()
    visited = self.generic_visit(node)
    self.pop_frame()
    self.push_onto_frame(ArgType.NONSPARSE)
    return visited
|
|
311
|
+
|
|
312
|
+
def visit_composite_generic_expr(self, node):
    """Visits each sub-expression in the current frame, so every element contributes its own arg_type(s)."""
    visited = self.generic_visit(node)
    return visited
|
|
315
|
+
|
|
316
|
+
def visit_partial_expr(self, node):
    """Returns a partial expression as-is: its children are not visited and no arg_type is recorded."""
    unchanged = node
    return unchanged
|
|
319
|
+
|
|
320
|
+
def visit_Assign(self, node):  # pylint: disable=invalid-name
    """
    Visitor for ast.Assign.

    Visits the assigned value in a fresh frame and records the arg_types it produced for every target.
    Targets are not visited. Returns a new ast.Assign holding the original targets and the visited value.
    """
    self.add_frame()
    new_value = self.visit(node.value)
    produced_types = self.pop_frame()
    for target in node.targets:
        self.map_type_to_target(target, produced_types)
    return ast.Assign(node.targets, new_value)
|
|
328
|
+
|
|
329
|
+
def visit_BinOp(self, node):  # pylint: disable=invalid-name
    """
    Visitor for ast.Binop.

    Visits both operands in a fresh frame and pushes the output arg_types of the matching sparse rule, or
    ArgType.NONSPARSE when get_binop_name yields no name for the operator. The node is returned as produced by
    generic_visit; no sparse call is substituted for it.
    """
    self.add_frame()
    node = self.generic_visit(node)
    operand_types = self.pop_frame()
    if len(operand_types) != 2:
        raise ValueError(f"Binary op {astunparse.unparse(node)} values for arg_type len({operand_types}) != 2")
    op_name = get_binop_name(node.op)
    if not op_name:
        result_types = (ArgType.NONSPARSE,)
    else:
        sparse_func = self.make_sparse_func(op_name, type(node), operand_types)
        if sparse_func is None:
            raise ValueError(f"Sparse rules not defined for {operand_types[0]} {op_name} {operand_types[1]}!")
        result_types = sparse_func.outputs
    self.push_all_onto_frame(result_types)
    return node
|
|
346
|
+
|
|
347
|
+
def visit_Call(self, node):  # pylint: disable=invalid-name
    """
    Visitor for ast.Call.

    Visits the positional arguments to collect their arg_types, then resolves the callee. Calls whose arguments
    are all non-sparse and builtin ops are kept and typed ArgType.NONSPARSE; tensor methods are typed via
    get_sparse_method_outputs; other callees (global functions, attributes of global namespaces and sub-cells
    reached through `self`) are sparsified through get_sparse_node.

    Raises:
        RuntimeError: if no function name can be extracted from the call.
        ValueError: if the callee cannot be resolved.
    """
    self.add_frame()
    args = [self.visit(arg) for arg in node.args]
    arg_types = self.pop_frame()
    # keep the visited arguments on the node, so the branches below that return `node` as is do not drop
    # rewrites applied to nested calls such as `f(g(x))`
    node.args = args

    # NOTE(review): the receiver of a method call is not part of arg_types, so e.g. a tensor method called on a
    # sparse variable with only non-sparse arguments is typed NONSPARSE here
    if all(t == ArgType.NONSPARSE for t in arg_types):
        # if none of the arguments is sparse, do nothing
        self.push_onto_frame(ArgType.NONSPARSE)
        return node

    # pylint: disable=protected-access
    func_name = AssignParser._get_func_name(node)
    if func_name is None or func_name == "":
        raise RuntimeError(f"Function not exist for {ast.dump(node)}!")
    # pylint: disable=protected-access
    func_scope = AssignParser._get_func_scope(node)

    if not func_scope:
        if func_name in builtin_ops:
            self.push_onto_frame(ArgType.NONSPARSE)
            return node
        if func_name in self.global_vars:
            # external function with sparse arguments are inlined and cached
            func = self.global_vars[func_name]
            return self.get_sparse_node(node, args, func, arg_types)
        raise ValueError(f"Call to undefined {func_name}!")

    if func_scope in self.global_vars:
        namespace = self.global_vars[func_scope]
        func = getattr(namespace, func_name, None)
        if func is None:
            raise ValueError(f"{func_name} not defined in {namespace}!")
        return self.get_sparse_node(node, args, func, arg_types)

    if func_scope == "self":
        func = self.init_vars.get(func_name, None)
        if func is None:
            raise ValueError(f"{func_name} not defined in Cell.__init__!")
        return self.get_sparse_node(node, args, func, arg_types)

    func_scope_type = self.type_map.get(func_scope, None)
    if func_scope_type is not None:
        # tensor methods
        if func_scope_type == ArgType.NONSPARSE:
            outputs = (ArgType.NONSPARSE,)
        else:
            outputs = get_sparse_method_outputs(func_name, func_scope_type)
        self.push_all_onto_frame(outputs)
        return node
    raise ValueError(f"Undefined var {func_scope}!")
|
|
400
|
+
|
|
401
|
+
def visit_Name(self, node):  # pylint: disable=invalid-name
    """
    Visitor for ast.Name.

    Pushes the arg_type(s) of the variable onto the current frame. Local variables use the type recorded in
    `type_map` (a tuple of arg_types is pushed element-wise); globals default to ArgType.NONSPARSE.

    Raises:
        ValueError: if the variable was declared dead after divergent control flow, or is undefined.
    """
    if node.id in self.type_map:
        tensor_type = self.type_map[node.id]
    elif node.id in self._dead_vars:
        # checked before globals: a dead local shadows a global of the same name, as it does in Python scoping
        raise ValueError(f"Divergent arg_types {self._dead_vars.get(node.id)} for {node.id} are currently not "
                         f"supported in control flow and the variable is considered dead upon leaving "
                         f"the block")
    elif node.id in self.global_vars:
        logger.warning(f"Global variable {node.id} treated as nonsparse value by default.")
        tensor_type = ArgType.NONSPARSE
    else:
        raise ValueError(f"Undefined variable {node.id}!")

    if isinstance(tensor_type, tuple):
        self.push_all_onto_frame(tensor_type)
    else:
        self.push_onto_frame(tensor_type)
    return node
|
|
420
|
+
|
|
421
|
+
def visit_Return(self, node):  # pylint: disable=invalid-name
    """Visits the returned expression in a fresh frame and stores the arg_types it produced as `return_types`."""
    self.add_frame()
    visited = self.generic_visit(node)
    self.return_types = self.pop_frame()
    return visited
|
|
427
|
+
|
|
428
|
+
def visit_While(self, node):  # pylint: disable=invalid-name
    """
    Visitor for ast.While.

    Variables for which arg_types diverge with control flow are not supported, and as a fallback routine,
    unsupported variables are treated as out-of-scope after leaving the control flow body.
    The loop test is visited in a throw-away frame. After the body, variables introduced in it stay visible,
    while variables whose arg_type changed are moved to `_dead_vars` as (new type, old type). The else
    branch is then visited with the merged type map.
    """
    self.add_frame()
    new_test = self.visit(node.test)
    self.pop_frame()
    merged = self.type_map.copy()
    new_body = [self.visit(stmt) for stmt in node.body]
    for name, arg_type in self.type_map.items():
        if name in merged:
            if merged[name] != arg_type:
                self._dead_vars[name] = (arg_type, merged.pop(name))
        else:
            merged[name] = arg_type
    self.type_map = merged
    new_orelse = [self.visit(stmt) for stmt in node.orelse]
    return ast.While(new_test, new_body, new_orelse)
|