mindspore 1.10.0__cp39-cp39-win_amd64.whl → 2.0.0rc1__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,17 +13,18 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.ClassDef which is subclass of Cell to SymbolTree."""
|
|
16
|
-
|
|
16
|
+
import sys
|
|
17
17
|
import ast
|
|
18
|
-
|
|
18
|
+
import inspect
|
|
19
19
|
import astunparse
|
|
20
20
|
from mindspore import log as logger
|
|
21
21
|
from mindspore._extends.parse.namespace import CellNamespace
|
|
22
|
+
from mindspore.rewrite.ast_creator_register import ast_creator_registry
|
|
22
23
|
from ..symbol_tree import SymbolTree
|
|
23
24
|
from ..parser import Parser
|
|
24
25
|
from ..parser_register import ParserRegister, reg_parser
|
|
25
|
-
from ..
|
|
26
|
-
from ..
|
|
26
|
+
from ..ast_helpers import AstReplacer
|
|
27
|
+
from ..common import error_str
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
class AstScopeChecker:
|
|
@@ -61,7 +62,9 @@ class AstScopeChecker:
|
|
|
61
62
|
return self._check_call(node)
|
|
62
63
|
if isinstance(node, (ast.Constant, ast.NameConstant, ast.Bytes, ast.Str, ast.Num)):
|
|
63
64
|
return True
|
|
64
|
-
raise RuntimeError("
|
|
65
|
+
raise RuntimeError(error_str(f"only support (ast.Compare, ast.Attribute, ast.Name, ast.BoolOp, ast.UnaryOp"
|
|
66
|
+
f"ast.Call, ast.Constant, ast.NameConstant, ast.Bytes, ast.Str, ast.Num"
|
|
67
|
+
f") as test check, but got ast type '{type(node).__name__}'", father_node=node))
|
|
65
68
|
|
|
66
69
|
def _check_attribute(self, node: ast.Attribute):
|
|
67
70
|
"""Check an ast.Attribute meets the constraints recursively."""
|
|
@@ -113,50 +116,88 @@ class ClassDefParser(Parser):
|
|
|
113
116
|
"""Parse target type"""
|
|
114
117
|
return ast.ClassDef
|
|
115
118
|
|
|
119
|
+
def handle_father_class(self, stree, node: ast.ClassDef):
|
|
120
|
+
"""Handle father class."""
|
|
121
|
+
for base in node.bases:
|
|
122
|
+
parser: Parser = ParserRegister.instance().get_parser(type(base))
|
|
123
|
+
father_class = parser.process(stree, base)
|
|
124
|
+
if "Cell" not in father_class:
|
|
125
|
+
for k, m in sys.modules.items():
|
|
126
|
+
if k in ("_ast", "ast"):
|
|
127
|
+
continue
|
|
128
|
+
if hasattr(m, father_class):
|
|
129
|
+
cls = getattr(m, father_class)
|
|
130
|
+
source_code = inspect.getsource(cls)
|
|
131
|
+
ast_root: ast.Module = ast.parse(source_code)
|
|
132
|
+
stree._father_class_ast.append(ast_root) # pylint: disable=protected-access
|
|
133
|
+
break
|
|
134
|
+
|
|
135
|
+
def process(self, stree: SymbolTree, node: ast.ClassDef):
|
|
136
|
+
"""
|
|
137
|
+
Parse init and construct in ast.ClassDef.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
141
|
+
node ([ast.ClassDef]): An ast.ClassDef node.
|
|
142
|
+
"""
|
|
143
|
+
replacer = AstReplacer(node)
|
|
144
|
+
replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
|
|
145
|
+
|
|
146
|
+
stree.set_class_ast(node)
|
|
147
|
+
self.handle_father_class(stree, node)
|
|
148
|
+
|
|
149
|
+
if self._need_add_init_func(node):
|
|
150
|
+
self._add_init_func(node)
|
|
151
|
+
|
|
152
|
+
for body in node.body:
|
|
153
|
+
if isinstance(body, ast.FunctionDef):
|
|
154
|
+
if body.name == "__init__":
|
|
155
|
+
self._process_init_func_ast(stree, node, body)
|
|
156
|
+
stree.set_init_func_ast(body)
|
|
157
|
+
elif body.name == "construct":
|
|
158
|
+
parser: Parser = ParserRegister.instance().get_parser(ast.FunctionDef)
|
|
159
|
+
parser.process(stree, body)
|
|
160
|
+
else:
|
|
161
|
+
logger.info(
|
|
162
|
+
"Ignoring ast.FunctionDef in ast.ClassDef except __init__ and construct function: %s",
|
|
163
|
+
body.name)
|
|
164
|
+
else:
|
|
165
|
+
logger.info("Ignoring unsupported node(%s) in ast.ClassDef.", type(body).__name__)
|
|
166
|
+
|
|
116
167
|
def _is_subtree_field(self, ori_net, field) -> bool:
|
|
117
168
|
op = getattr(ori_net, field)
|
|
118
169
|
return not type(op).__name__ in self._cell_namespace
|
|
119
170
|
|
|
120
|
-
def _process_init_func_ast(self, stree: SymbolTree, init_ast: ast.FunctionDef):
|
|
171
|
+
def _process_init_func_ast(self, stree: SymbolTree, cls_ast: ast.ClassDef, init_ast: ast.FunctionDef):
|
|
121
172
|
"""Process init func"""
|
|
122
|
-
super_index = ClassDefParser._find_super_expr_of_init_func(init_ast)
|
|
123
173
|
ClassDefParser._modify_arguments_of_init_func(init_ast)
|
|
124
|
-
self._replace_ori_field_of_init_func(stree, init_ast.body
|
|
125
|
-
|
|
126
|
-
super_index = ClassDefParser._find_super_expr_of_init_func(init_ast)
|
|
127
|
-
ClassDefParser._insert_handler_to_init_func(init_ast, super_index)
|
|
174
|
+
new_bodies = self._replace_ori_field_of_init_func(stree, cls_ast, init_ast.body)
|
|
175
|
+
init_ast.body = new_bodies
|
|
128
176
|
|
|
129
177
|
@staticmethod
|
|
130
|
-
def
|
|
131
|
-
"""
|
|
132
|
-
if not
|
|
133
|
-
return
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
if expr_value_func.attr != "__init__" or not isinstance(expr_value_func_value, ast.Call):
|
|
148
|
-
continue
|
|
149
|
-
expr_value_func_value_func = expr_value_func_value.func
|
|
150
|
-
if not isinstance(expr_value_func_value_func, ast.Name) or expr_value_func_value_func.id != "super":
|
|
151
|
-
continue
|
|
152
|
-
break
|
|
153
|
-
return super_index
|
|
178
|
+
def _is_super_expr(expr: ast.AST) -> bool:
|
|
179
|
+
"""Check whether ast node is super().__init__()"""
|
|
180
|
+
if not isinstance(expr, ast.Expr):
|
|
181
|
+
return False
|
|
182
|
+
expr_value = expr.value
|
|
183
|
+
if not isinstance(expr_value, ast.Call):
|
|
184
|
+
return False
|
|
185
|
+
expr_value_func = expr_value.func
|
|
186
|
+
if not isinstance(expr_value_func, ast.Attribute):
|
|
187
|
+
return False
|
|
188
|
+
expr_value_func_value = expr_value_func.value
|
|
189
|
+
if expr_value_func.attr != "__init__" or not isinstance(expr_value_func_value, ast.Call):
|
|
190
|
+
return False
|
|
191
|
+
expr_value_func_value_func = expr_value_func_value.func
|
|
192
|
+
if not isinstance(expr_value_func_value_func, ast.Name) or expr_value_func_value_func.id != "super":
|
|
193
|
+
return False
|
|
194
|
+
return True
|
|
154
195
|
|
|
155
196
|
@staticmethod
|
|
156
197
|
def _modify_arguments_of_init_func(ast_init_fn: ast.FunctionDef):
|
|
157
198
|
"""Replace init function input parameters to self and global_vars."""
|
|
158
199
|
arg_self = ast.arg(arg="self", annotation="")
|
|
159
|
-
arg_global_vars = ast.arg(arg="
|
|
200
|
+
arg_global_vars = ast.arg(arg="obj", annotation="")
|
|
160
201
|
ast_init_fn.args = ast.arguments(args=[arg_self, arg_global_vars], posonlyargs=[], kwonlyargs=[],
|
|
161
202
|
kw_defaults=[], defaults=[], vararg=None, kwarg=None)
|
|
162
203
|
ast.fix_missing_locations(ast_init_fn)
|
|
@@ -187,7 +228,116 @@ class ClassDefParser(Parser):
|
|
|
187
228
|
for counter, index in enumerate(body_index_to_be_deleted):
|
|
188
229
|
bodies.pop(index - counter)
|
|
189
230
|
|
|
190
|
-
def
|
|
231
|
+
def _handle_tuple_for_replace_ori_field(self, ast_tuple: ast.Tuple, new_bodies):
|
|
232
|
+
""" Handle ast.Assign node with target of ast.Tuple in init func to new ast nodes. """
|
|
233
|
+
for e in ast_tuple.elts:
|
|
234
|
+
if isinstance(e, ast.Attribute):
|
|
235
|
+
field_name = e.attr
|
|
236
|
+
value = ast.Call(ast.Name('getattr', ast.Load()),
|
|
237
|
+
[ast.Name('obj', ast.Load()),
|
|
238
|
+
ast.Constant(value=field_name, kind=None)], [])
|
|
239
|
+
new_assign = ast_creator_registry.get("Assign")(targets=[e], value=value)
|
|
240
|
+
new_bodies.append(new_assign)
|
|
241
|
+
|
|
242
|
+
def _handle_express_for_replace_ori_field(self, cls_ast, ast_expr: ast.Expr, stree, new_bodies):
|
|
243
|
+
""" Handle ast.Expr node in init func to new ast nodes. """
|
|
244
|
+
ast_call = ast_expr.value
|
|
245
|
+
if not isinstance(ast_call.func, ast.Attribute) or not isinstance(ast_call.func.value, ast.Name)\
|
|
246
|
+
or ast_call.func.value.id != 'self':
|
|
247
|
+
return
|
|
248
|
+
for func_def in cls_ast.body:
|
|
249
|
+
if isinstance(func_def, ast.FunctionDef) and func_def.name == ast_call.func.attr:
|
|
250
|
+
for func_def_body in func_def.body:
|
|
251
|
+
self._handle_bodies_for_replace_ori_field(cls_ast, func_def_body, stree, new_bodies)
|
|
252
|
+
return
|
|
253
|
+
|
|
254
|
+
def _handle_assign_for_replace_ori_field(self, ast_assign: ast.Assign, stree, new_bodies):
|
|
255
|
+
""" Handle ast.Assign node in init func to new ast nodes. """
|
|
256
|
+
if len(ast_assign.targets) != 1:
|
|
257
|
+
raise RuntimeError("not support multi-targets in assign now!", father_node=ast_assign)
|
|
258
|
+
target = ast_assign.targets[0]
|
|
259
|
+
if isinstance(target, ast.Tuple):
|
|
260
|
+
self._handle_tuple_for_replace_ori_field(target, new_bodies)
|
|
261
|
+
return
|
|
262
|
+
if not isinstance(target, ast.Attribute) or not isinstance(target.value, ast.Name)\
|
|
263
|
+
or target.value.id != 'self':
|
|
264
|
+
logger.info(f"Ignoring {astunparse.unparse(target)} in __init__ function.")
|
|
265
|
+
return
|
|
266
|
+
field_name = target.attr
|
|
267
|
+
# Ensure that the instance has corresponding attribute
|
|
268
|
+
if not hasattr(stree.get_origin_network(), field_name):
|
|
269
|
+
return
|
|
270
|
+
# Check to avoid repeat code
|
|
271
|
+
for new_ast in new_bodies:
|
|
272
|
+
if isinstance(new_ast, ast.Assign) and isinstance(new_ast.targets[0], ast.Attribute)\
|
|
273
|
+
and new_ast.targets[0].attr == field_name:
|
|
274
|
+
return
|
|
275
|
+
value = ast.Call(ast.Name('getattr', ast.Load()),
|
|
276
|
+
[ast.Name('obj', ast.Load()),
|
|
277
|
+
ast.Constant(value=field_name, kind=None)], [])
|
|
278
|
+
new_assign = ast_creator_registry.get("Assign")(targets=[target], value=value)
|
|
279
|
+
new_bodies.append(new_assign)
|
|
280
|
+
|
|
281
|
+
def _handle_bodies_for_replace_ori_field(self, cls_ast, body, stree, new_bodies):
|
|
282
|
+
""" handle_bodies_for_replace_ori_field. """
|
|
283
|
+
if self._is_super_expr(body):
|
|
284
|
+
new_bodies.append(body)
|
|
285
|
+
return
|
|
286
|
+
if isinstance(body, ast.If):
|
|
287
|
+
for if_body in body.body + body.orelse:
|
|
288
|
+
self._handle_bodies_for_replace_ori_field(cls_ast, if_body, stree, new_bodies)
|
|
289
|
+
return
|
|
290
|
+
if isinstance(body, ast.Expr) and isinstance(body.value, ast.Call):
|
|
291
|
+
self._handle_express_for_replace_ori_field(cls_ast, body, stree, new_bodies)
|
|
292
|
+
return
|
|
293
|
+
if isinstance(body, ast.Assign): # if not assign node, delete
|
|
294
|
+
self._handle_assign_for_replace_ori_field(body, stree, new_bodies)
|
|
295
|
+
return
|
|
296
|
+
|
|
297
|
+
def _need_add_init_func(self, cls_ast: ast.ClassDef) -> bool:
|
|
298
|
+
"""If class has base nn.Cell but not have init func, then we need to add init func"""
|
|
299
|
+
base_nn_cell = False
|
|
300
|
+
for base in cls_ast.bases:
|
|
301
|
+
if isinstance(base, ast.Name) and base.id == 'Cell'\
|
|
302
|
+
or isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name)\
|
|
303
|
+
and base.value.id == "nn" and base.attr == 'Cell':
|
|
304
|
+
base_nn_cell = True
|
|
305
|
+
break
|
|
306
|
+
if not base_nn_cell:
|
|
307
|
+
return False
|
|
308
|
+
for body in cls_ast.body:
|
|
309
|
+
if isinstance(body, ast.FunctionDef) and body.name == '__init__':
|
|
310
|
+
return False
|
|
311
|
+
return True
|
|
312
|
+
|
|
313
|
+
def _add_init_func(self, cls_ast: ast.ClassDef):
|
|
314
|
+
"""Add init func with super().__init__()"""
|
|
315
|
+
init_func_ast = ast.FunctionDef(
|
|
316
|
+
name='__init__',
|
|
317
|
+
args=ast.arguments(
|
|
318
|
+
posonlyargs=[],
|
|
319
|
+
args=[
|
|
320
|
+
ast.arg(arg='self')],
|
|
321
|
+
kwonlyargs=[],
|
|
322
|
+
kw_defaults=[],
|
|
323
|
+
defaults=[]),
|
|
324
|
+
body=[
|
|
325
|
+
ast.Expr(
|
|
326
|
+
value=ast.Call(
|
|
327
|
+
func=ast.Attribute(
|
|
328
|
+
value=ast.Call(
|
|
329
|
+
func=ast.Name(id='super', ctx=ast.Load()),
|
|
330
|
+
args=[],
|
|
331
|
+
keywords=[]),
|
|
332
|
+
attr='__init__',
|
|
333
|
+
ctx=ast.Load()),
|
|
334
|
+
args=[],
|
|
335
|
+
keywords=[]))],
|
|
336
|
+
decorator_list=[])
|
|
337
|
+
cls_ast.body.insert(0, init_func_ast)
|
|
338
|
+
ast.fix_missing_locations(cls_ast)
|
|
339
|
+
|
|
340
|
+
def _replace_ori_field_of_init_func(self, stree: SymbolTree, cls_ast: ast.ClassDef, bodies: []):
|
|
191
341
|
"""
|
|
192
342
|
Replace original field in init func to self.XX = getattr(self._handler, "XX").
|
|
193
343
|
Only keep following two kinds of ast nodes in bodies right now:
|
|
@@ -196,83 +346,16 @@ class ClassDefParser(Parser):
|
|
|
196
346
|
|
|
197
347
|
Args:
|
|
198
348
|
bodies ([]): bodied of init ast.FunctionDef.
|
|
199
|
-
super_index (int): index of super().__init__() in bodies.
|
|
200
349
|
|
|
201
350
|
Raises:
|
|
202
351
|
RuntimeError: Not support multi-targets in assign.
|
|
203
352
|
RuntimeError: Only support target.value in [ast.Name] in assign node.
|
|
204
353
|
"""
|
|
205
|
-
body_index_to_be_deleted = []
|
|
206
|
-
scope_checker = AstScopeChecker("self")
|
|
207
|
-
for body_index, body in enumerate(bodies):
|
|
208
|
-
if body_index == super_index:
|
|
209
|
-
continue # ignoring super.__init__()
|
|
210
|
-
if isinstance(body, ast.If):
|
|
211
|
-
if scope_checker.check(body.test):
|
|
212
|
-
self._replace_ori_field_of_init_func(stree, body.body, -1)
|
|
213
|
-
self._replace_ori_field_of_init_func(stree, body.orelse, -1)
|
|
214
|
-
continue
|
|
215
|
-
logger.info("Ignoring un-eval-able if: %s", astunparse.unparse(body.test))
|
|
216
|
-
if not isinstance(body, ast.Assign): # if not assign node, delete
|
|
217
|
-
body_index_to_be_deleted.append(body_index)
|
|
218
|
-
continue
|
|
219
|
-
if len(body.targets) != 1:
|
|
220
|
-
raise RuntimeError("Not support multi-targets in assign now!")
|
|
221
|
-
target = body.targets[0]
|
|
222
|
-
if not isinstance(target, ast.Attribute): # only keep class member
|
|
223
|
-
body_index_to_be_deleted.append(body_index)
|
|
224
|
-
continue
|
|
225
|
-
if not isinstance(target.value, ast.Name):
|
|
226
|
-
raise RuntimeError("Only support target.value in ast.Name now!")
|
|
227
|
-
target_value: ast.Name = target.value
|
|
228
|
-
if target_value.id != "self":
|
|
229
|
-
body_index_to_be_deleted.append(body_index)
|
|
230
|
-
continue
|
|
231
|
-
field_name = target.attr
|
|
232
|
-
body.value = ast.Call(ast.Name('getattr', ast.Load()),
|
|
233
|
-
[ast.Attribute(ast.Name('self', ast.Load()), '_handler', ast.Load()),
|
|
234
|
-
ast.Constant(value=field_name, kind=None)], [])
|
|
235
|
-
for counter, index in enumerate(body_index_to_be_deleted):
|
|
236
|
-
bodies.pop(index - counter)
|
|
237
|
-
ClassDefParser._remove_empty_ast_in_init_func(bodies)
|
|
238
|
-
|
|
239
|
-
@staticmethod
|
|
240
|
-
def _insert_handler_to_init_func(ast_init_fn: ast.FunctionDef, super_index):
|
|
241
|
-
"""Insert 'self._handler = global_vars.get('handler')' to init ast.FunctionDef.body"""
|
|
242
|
-
if super_index == -1:
|
|
243
|
-
super_index = 0
|
|
244
|
-
AstModifier.insert_assign_to_function(ast_init_fn, [ScopedValue.create_naming_value("_handler", "self")],
|
|
245
|
-
ScopedValue.create_naming_value("get", "global_vars"),
|
|
246
|
-
[ScopedValue.create_variable_value("handler")], None,
|
|
247
|
-
ast_init_fn.body[super_index], False)
|
|
248
|
-
|
|
249
|
-
def process(self, stree: SymbolTree, node: ast.ClassDef):
|
|
250
|
-
"""
|
|
251
|
-
Parse init and construct in ast.ClassDef.
|
|
252
|
-
|
|
253
|
-
Args:
|
|
254
|
-
stree ([SymbolTree]): Symbol Tree under parsing.
|
|
255
|
-
node ([ast.ClassDef]): An ast.ClassDef node.
|
|
256
|
-
"""
|
|
257
|
-
replacer = AstReplacer(node)
|
|
258
|
-
replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name())
|
|
259
354
|
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
if body.name == "__init__":
|
|
265
|
-
self._process_init_func_ast(stree, body)
|
|
266
|
-
stree.set_init_func_ast(body)
|
|
267
|
-
elif body.name == "construct":
|
|
268
|
-
parser: Parser = ParserRegister.instance().get_parser(ast.FunctionDef)
|
|
269
|
-
parser.process(stree, body)
|
|
270
|
-
else:
|
|
271
|
-
logger.info(
|
|
272
|
-
"Ignoring ast.FunctionDef in ast.ClassDef except __init__ and construct function: %s",
|
|
273
|
-
body.name)
|
|
274
|
-
else:
|
|
275
|
-
logger.info("Ignoring unsupported node(%s) in ast.ClassDef.", type(body).__name__)
|
|
355
|
+
new_bodies = []
|
|
356
|
+
for body in bodies:
|
|
357
|
+
self._handle_bodies_for_replace_ori_field(cls_ast, body, stree, new_bodies)
|
|
358
|
+
return new_bodies
|
|
276
359
|
|
|
277
360
|
|
|
278
361
|
# Register a ClassDefParser instance through reg_parser; its return value is kept as a module-level name.
g_classdef_parser = reg_parser(ClassDefParser())
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Parse ast.Assign in construct function to node of SymbolTree."""
|
|
16
|
+
import ast
|
|
17
|
+
|
|
18
|
+
from mindspore.rewrite.parser import Parser
|
|
19
|
+
from mindspore.rewrite.symbol_tree import SymbolTree
|
|
20
|
+
from mindspore.rewrite.parser_register import reg_parser
|
|
21
|
+
from ..common import error_str
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NameParser(Parser):
    """Parse ast.Name in construct function into the identifier it refers to."""

    def target(self):
        """Return the AST node type handled by this parser (ast.Name)."""
        return ast.Name

    def process(self, stree: SymbolTree, node: ast.Name):
        """
        Parse ast.Name node.

        Args:
            stree ([SymbolTree]): Symbol Tree under parsing (not used by this parser).
            node ([ast.Name]): An ast.Name node.

        Returns:
            The identifier stored in ``node.id``.

        Raises:
            TypeError: Name parser only supports parsing ast.Name type nodes.
        """
        if not isinstance(node, ast.Name):
            # Trailing space keeps the offending type name from being glued onto "type"
            # (the Num/Str parsers already format it this way).
            raise TypeError(error_str(f"name parser only supports parsing ast.Name type nodes, but got ast type "
                                      f"'{type(node).__name__}'", father_node=node))
        return node.id
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class NumParser(Parser):
    """Parser that turns an ast.Num node of the construct function into its numeric value."""

    def target(self):
        """Return the AST node type handled by this parser (ast.Num)."""
        return ast.Num

    def process(self, stree: SymbolTree, node: ast.Num):
        """
        Extract the number held by an ast.Num node.

        Args:
            stree ([SymbolTree]): Symbol Tree under parsing (not used by this parser).
            node ([ast.Num]): Node whose value is extracted.

        Returns:
            The ``n`` attribute of ``node``.

        Raises:
            TypeError: If ``node`` is not an ast.Num instance.
        """
        if isinstance(node, ast.Num):
            return node.n
        node_type_name = type(node).__name__
        raise TypeError(error_str(f"num parser only supports parsing ast.Num type nodes, but got ast type "
                                  f"'{node_type_name}'", father_node=node))
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class StrParser(Parser):
    """Parser that turns an ast.Str node of the construct function into its string value."""

    def target(self):
        """Return the AST node type handled by this parser (ast.Str)."""
        return ast.Str

    def process(self, stree: SymbolTree, node: ast.Str):
        """
        Extract the string held by an ast.Str node.

        Args:
            stree ([SymbolTree]): Symbol Tree under parsing (not used by this parser).
            node ([ast.Str]): Node whose value is extracted.

        Returns:
            The ``s`` attribute of ``node``.

        Raises:
            TypeError: If ``node`` is not an ast.Str instance.
        """
        if isinstance(node, ast.Str):
            return node.s
        node_type_name = type(node).__name__
        raise TypeError(error_str(f"str parser only supports parsing ast.Str type nodes, but got ast type "
                                  f"'{node_type_name}'", father_node=node))
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Register a NameParser instance through reg_parser; its return value is kept as a module-level name.
g_name_parser = reg_parser(NameParser())
|
|
100
|
+
# Register a NumParser instance through reg_parser; its return value is kept as a module-level name.
g_num_parser = reg_parser(NumParser())
|
|
101
|
+
# Register a StrParser instance through reg_parser; its return value is kept as a module-level name.
g_str_parser = reg_parser(StrParser())
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Parse Container in construct function to node of SymbolTree."""
|
|
16
|
+
import ast
|
|
17
|
+
|
|
18
|
+
from mindspore.rewrite.parser import Parser
|
|
19
|
+
from mindspore.rewrite.symbol_tree import SymbolTree
|
|
20
|
+
from mindspore.rewrite.parser_register import ParserRegister
|
|
21
|
+
|
|
22
|
+
from mindspore.rewrite.parser_register import reg_parser
|
|
23
|
+
from ..common import error_str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ListParser(Parser):
    """Parse a list in construct function into a list of parsed element values."""

    def target(self):
        """Return the type handled by this parser (list)."""
        return list

    def process(self, stree: SymbolTree, node: list):
        """
        Parse list.

        Each element is dispatched to the parser registered for its exact type, and the
        results are collected in order.

        Args:
            stree ([SymbolTree]): Symbol Tree under parsing, forwarded to element parsers.
            node ([list]): A list of nodes.

        Returns:
            A list of value.

        Raises:
            TypeError: List parser only supports parsing list type nodes.
        """
        # This parser is registered for `list`, so the guard must accept lists. The previous
        # check against ast.Str rejected every legitimate input.
        if not isinstance(node, list):
            raise TypeError(error_str(f"list parser only supports parsing list type nodes, but got ast type "
                                      f"'{type(node).__name__}'", father_node=node))
        registry = ParserRegister.instance()
        return [registry.get_parser(type(element)).process(stree, element) for element in node]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class TupleParser(Parser):
    """Parse a tuple in construct function into a tuple of parsed element values."""

    def target(self):
        """Return the type handled by this parser (tuple)."""
        return tuple

    def process(self, stree: SymbolTree, node: tuple):
        """
        Parse tuple.

        Each element is dispatched to the parser registered for its exact type, and the
        results are collected in order.

        Args:
            stree ([SymbolTree]): Symbol Tree under parsing, forwarded to element parsers.
            node ([tuple]): A tuple of nodes.

        Returns:
            A tuple of value.

        Raises:
            TypeError: Tuple parser only supports parsing tuple type nodes.
        """
        # Enforce the contract documented above; previously non-tuple input was not rejected.
        if not isinstance(node, tuple):
            raise TypeError(error_str(f"tuple parser only supports parsing tuple type nodes, but got ast type "
                                      f"'{type(node).__name__}'", father_node=node))
        registry = ParserRegister.instance()
        return tuple(registry.get_parser(type(element)).process(stree, element) for element in node)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Register a ListParser instance through reg_parser; its return value is kept as a module-level name.
g_list_parser = reg_parser(ListParser())
|
|
88
|
+
# Register a TupleParser instance through reg_parser; its return value is kept as a module-level name.
g_tuple_parser = reg_parser(TupleParser())
|
|
@@ -13,13 +13,13 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
""" Parse ast.For node """
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
import ast
|
|
18
17
|
import astunparse
|
|
19
18
|
|
|
20
19
|
from mindspore.rewrite.api.scoped_value import ScopedValue, ValueType
|
|
21
20
|
from mindspore.rewrite.ast_helpers.ast_modifier import AstModifier
|
|
22
21
|
from mindspore import log as logger
|
|
22
|
+
from mindspore import nn
|
|
23
23
|
from ..symbol_tree import SymbolTree
|
|
24
24
|
from ..parser import Parser
|
|
25
25
|
from ..parser_register import reg_parser
|
|
@@ -30,16 +30,18 @@ EVAL_WHITE_LIST = ("self.", "range(", "zip(", "enumerate(", "reversed(")
|
|
|
30
30
|
|
|
31
31
|
class ForParser(Parser):
|
|
32
32
|
""" Class that implements parsing ast.For nodes """
|
|
33
|
+
|
|
33
34
|
@staticmethod
def modify_init_ast(stree, i, obj, iter_var_name):
    """
    Attach ``obj`` to the original network under a per-index name and record it in ``__init__``.

    The attribute name is ``<iter_var_name stripped>_<i>``. The object is first set with
    ``setattr`` and then registered with ``insert_child_to_cell``; that order is kept as-is.
    Finally, AstModifier is asked to insert an assignment into the init function AST. The
    assignment's target is ``self.<name>`` and its value is a ``getattr`` call with ``obj``
    and the string ``"<name>"``.

    Args:
        stree: Symbol tree whose original network and init function AST are updated.
        i (int): Position of ``obj`` in the iterated container, used as the name suffix.
        obj: Object to attach to the network.
        iter_var_name (str): Base name of the iterated attribute.
    """
    # NOTE(review): the inserted getattr(obj, ...) presumes a name `obj` is in scope inside the
    # generated __init__ — TODO confirm against the init function signature.
    attr_name = "{}_{}".format(iter_var_name.strip(), str(i))
    setattr(stree.get_origin_network(), attr_name, obj)
    stree.get_origin_network().insert_child_to_cell(attr_name, obj)
    assign_targets = [ScopedValue(ValueType.NamingValue, "self", attr_name)]
    getattr_func = ScopedValue(ValueType.NamingValue, "", "getattr")
    getattr_args = [ScopedValue(ValueType.NamingValue, "", "obj"),
                    ScopedValue(ValueType.StringValue, "", attr_name)]
    AstModifier.insert_assign_to_function(stree.get_init_func_ast(),
                                          targets=assign_targets,
                                          expr=getattr_func,
                                          args=getattr_args)
|
|
43
45
|
|
|
44
46
|
@staticmethod
|
|
45
47
|
def modify_construct_ast(stree, ast_node, old_name, new_name):
|
|
@@ -59,20 +61,27 @@ class ForParser(Parser):
|
|
|
59
61
|
targets = node.target.id
|
|
60
62
|
iter_code = astunparse.unparse(node.iter)
|
|
61
63
|
if not iter_code.startswith(EVAL_WHITE_LIST):
|
|
62
|
-
logger.warning(
|
|
64
|
+
logger.warning(
|
|
65
|
+
f"For MindSpore Rewrtie, illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}")
|
|
63
66
|
return
|
|
64
|
-
if
|
|
67
|
+
if "self" in iter_code:
|
|
65
68
|
iter_code = iter_code.replace("self", "stree.get_origin_network()")
|
|
66
69
|
try:
|
|
67
70
|
iter_obj = eval(iter_code)
|
|
68
|
-
except
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
71
|
+
except (NameError, TypeError) as e:
|
|
72
|
+
_info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
|
|
73
|
+
f"an error occurred: {str(e)}"
|
|
74
|
+
logger.warning(_info)
|
|
75
|
+
stree.try_append_python_node(node, node)
|
|
76
|
+
return
|
|
72
77
|
|
|
73
78
|
iter_var_name = iter_code.split(".")[-1]
|
|
74
79
|
index = stree.get_ast_root().body.index(node) + 1
|
|
75
80
|
if isinstance(iter_obj, list):
|
|
81
|
+
for obj in iter_obj:
|
|
82
|
+
if not isinstance(obj, nn.Cell):
|
|
83
|
+
stree.try_append_python_node(node, node)
|
|
84
|
+
return
|
|
76
85
|
for i, obj in enumerate(iter_obj):
|
|
77
86
|
ForParser.modify_init_ast(stree, i, obj, iter_var_name)
|
|
78
87
|
for body in node.body:
|
|
@@ -82,13 +91,17 @@ class ForParser(Parser):
|
|
|
82
91
|
index += 1
|
|
83
92
|
if stree.get_ori_cls_name() == "SequentialCell":
|
|
84
93
|
stree.on_change(Event.CodeChangeEvent)
|
|
85
|
-
|
|
86
|
-
|
|
94
|
+
stree.get_ast_root().body.remove(node)
|
|
95
|
+
return
|
|
96
|
+
if isinstance(iter_obj, range):
|
|
97
|
+
logger.warning("For MindSpore Rewrite, range not support.")
|
|
87
98
|
elif isinstance(iter_obj, zip):
|
|
88
|
-
|
|
99
|
+
logger.warning("For MindSpore Rewrite, zip not support.")
|
|
89
100
|
elif isinstance(iter_obj, enumerate):
|
|
90
|
-
|
|
101
|
+
logger.warning("For MindSpore Rewrite, enumerate not support.")
|
|
91
102
|
else:
|
|
92
|
-
|
|
103
|
+
logger.warning(f"For MindSpore Rewrite, not supported type: {type(iter_obj).__name__}")
|
|
104
|
+
stree.try_append_python_node(node, node)
|
|
105
|
+
return
|
|
93
106
|
|
|
94
107
|
# Register a ForParser instance through reg_parser; its return value is kept as a module-level name.
g_for_parser = reg_parser(ForParser())
|