mindspore 1.10.0__cp38-cp38-win_amd64.whl → 2.0.0rc1__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -13,11 +13,14 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Ast optimizer for flatten recursive call."""
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
from typing import Any, Tuple
|
|
18
18
|
import ast
|
|
19
19
|
from ast import FunctionDef
|
|
20
|
+
import astunparse
|
|
21
|
+
|
|
20
22
|
from mindspore import log as logger
|
|
23
|
+
from ..common import error_str
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
@@ -35,7 +38,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
35
38
|
ast.Call: ["args"],
|
|
36
39
|
ast.BinOp: ["left", "right"],
|
|
37
40
|
ast.BoolOp: ["values"],
|
|
38
|
-
ast.
|
|
41
|
+
ast.UnaryOp: ["operand"],
|
|
42
|
+
ast.Compare: ["left", "comparators"],
|
|
39
43
|
}
|
|
40
44
|
|
|
41
45
|
@staticmethod
|
|
@@ -52,7 +56,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
52
56
|
target_name = "function"
|
|
53
57
|
elif isinstance(node, ast.Return):
|
|
54
58
|
target_name = "return_value"
|
|
55
|
-
elif isinstance(node, (ast.BinOp, ast.
|
|
59
|
+
elif isinstance(node, (ast.BinOp, ast.BoolOp, ast.UnaryOp)):
|
|
56
60
|
target_name = type(node.op).__name__.lower() + "_var"
|
|
57
61
|
elif isinstance(node, ast.Tuple):
|
|
58
62
|
target_name = type(node).__name__.lower() + "_var"
|
|
@@ -79,6 +83,20 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
79
83
|
new_target_name = FlattenRecursiveStmt._generate_target_name(node, target_names)
|
|
80
84
|
return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)
|
|
81
85
|
|
|
86
|
+
@staticmethod
|
|
87
|
+
def _flatten_list(node_list, target_names):
|
|
88
|
+
"""Flatten a list of node."""
|
|
89
|
+
results = list()
|
|
90
|
+
new_list = list()
|
|
91
|
+
for node in node_list:
|
|
92
|
+
if isinstance(node, ast.Call):
|
|
93
|
+
new_target, new_node = FlattenRecursiveStmt._create_new_assign_node(node, target_names)
|
|
94
|
+
results.append(new_node)
|
|
95
|
+
new_list.append(ast.Name(id=new_target, ctx=ast.Load()))
|
|
96
|
+
else:
|
|
97
|
+
new_list.append(node)
|
|
98
|
+
return results, new_list
|
|
99
|
+
|
|
82
100
|
def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]:
|
|
83
101
|
"""Flatten recursive statement according to different node type."""
|
|
84
102
|
flatten_config = self._flatten_table.get(type(node))
|
|
@@ -96,6 +114,10 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
96
114
|
else:
|
|
97
115
|
new_list.append(ast.Name(id=new_target_name, ctx=ast.Load()))
|
|
98
116
|
results.append(new_node)
|
|
117
|
+
if isinstance(todo, (ast.Tuple, tuple)):
|
|
118
|
+
_res, _new_list = FlattenRecursiveStmt._flatten_list(new_node.value.elts, [new_target_name])
|
|
119
|
+
new_node.value.elts = _new_list
|
|
120
|
+
results.extend(_res)
|
|
99
121
|
setattr(node, todo_name, new_list)
|
|
100
122
|
elif isinstance(todos, dict):
|
|
101
123
|
new_dict = []
|
|
@@ -130,7 +152,9 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
130
152
|
targets = child.targets
|
|
131
153
|
for target in targets:
|
|
132
154
|
if not isinstance(target, (ast.Name, ast.Tuple)):
|
|
133
|
-
raise RuntimeError(
|
|
155
|
+
raise RuntimeError(
|
|
156
|
+
error_str(f"currently only support ast.Name targets, but got ast type "
|
|
157
|
+
f"'{type(target).__name__}'", child_node=target, father_node=child))
|
|
134
158
|
if isinstance(target, ast.Name):
|
|
135
159
|
target_name = target.id
|
|
136
160
|
if target_name not in target_names:
|
|
@@ -138,7 +162,10 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
138
162
|
elif isinstance(target, ast.Tuple):
|
|
139
163
|
for elt in target.elts:
|
|
140
164
|
if not isinstance(elt, ast.Name):
|
|
141
|
-
raise RuntimeError(
|
|
165
|
+
raise RuntimeError(
|
|
166
|
+
error_str(f"currently only support ast.Name in ast.Tuple, "
|
|
167
|
+
f"but got ast type '{type(elt).__name__}'", child_node=elt,
|
|
168
|
+
father_node=child))
|
|
142
169
|
target_name = elt.id
|
|
143
170
|
if target_name not in target_names:
|
|
144
171
|
target_names.append(target_name)
|
|
@@ -155,6 +182,20 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
155
182
|
child = node.body[index]
|
|
156
183
|
if isinstance(child, ast.Assign):
|
|
157
184
|
stmt = child.value
|
|
185
|
+
elif isinstance(child, ast.If):
|
|
186
|
+
if isinstance(child.body[0], ast.Return) and not isinstance(child.test, ast.UnaryOp):
|
|
187
|
+
if isinstance(child.body[0].value, ast.Call):
|
|
188
|
+
if_body = child.body
|
|
189
|
+
if_func = if_body[0].value
|
|
190
|
+
expr = "x = " + astunparse.unparse(if_func)
|
|
191
|
+
if_body = ast.parse(expr)
|
|
192
|
+
if_body = if_body.body+ast.parse("return x").body
|
|
193
|
+
child.body = if_body
|
|
194
|
+
stmt = child
|
|
195
|
+
else:
|
|
196
|
+
stmt = child
|
|
197
|
+
else:
|
|
198
|
+
stmt = child
|
|
158
199
|
elif isinstance(child, ast.Expr):
|
|
159
200
|
stmt = child.value
|
|
160
201
|
else:
|
|
@@ -13,12 +13,14 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Fold if return."""
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
import ast
|
|
18
18
|
import copy
|
|
19
19
|
from typing import Any, Union
|
|
20
20
|
from enum import Enum
|
|
21
21
|
|
|
22
|
+
from ..common import error_str
|
|
23
|
+
|
|
22
24
|
|
|
23
25
|
class ReturnType(Enum):
|
|
24
26
|
"""
|
|
@@ -121,7 +123,7 @@ class RemoveReturnOutOfIf(ast.NodeTransformer):
|
|
|
121
123
|
RuntimeError: Father node has not input attr.
|
|
122
124
|
"""
|
|
123
125
|
if not hasattr(father_node, attr):
|
|
124
|
-
raise RuntimeError(
|
|
126
|
+
raise RuntimeError(error_str(f"Father node has not input attr '{attr}'", father_node=father_node))
|
|
125
127
|
father_node_attr = getattr(father_node, attr)
|
|
126
128
|
if RemoveReturnOutOfIf._last_node_is_return(if_node) == ReturnType.IfNotAllReturn:
|
|
127
129
|
# nodes should be copied to all branches which not end with return
|
|
@@ -193,7 +195,8 @@ class RemoveReturnOutOfIf(ast.NodeTransformer):
|
|
|
193
195
|
|
|
194
196
|
# assert body and or-else all end with return
|
|
195
197
|
if not isinstance(last_node.body[-1], ast.Return) or not isinstance(last_node.orelse[-1], ast.Return):
|
|
196
|
-
raise RuntimeError("Body and orelse of if nodes not all end with ast.Return."
|
|
198
|
+
raise RuntimeError(error_str("Body and orelse of if nodes not all end with ast.Return.",
|
|
199
|
+
father_node=last_node))
|
|
197
200
|
output_name = RemoveReturnOutOfIf._get_output_names(output_names)
|
|
198
201
|
# replace body return
|
|
199
202
|
body_new_last_node = ast.Assign(
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Base class, observable of observer design pattern."""
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
from .observer import Observer
|
|
18
18
|
from .event import Event
|
|
19
19
|
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Abstract class, observer of observer design pattern."""
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
import abc
|
|
18
18
|
from .event import Event
|
|
19
19
|
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""Error Log for Rewrite."""
|
|
16
|
+
|
|
17
|
+
import ast
|
|
18
|
+
import astunparse
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def error_str(reason: str, child_node: ast.expr = None, father_node: ast.expr = None) -> str:
|
|
22
|
+
"""Raise error func for Rewrite."""
|
|
23
|
+
output_estr = "Unsupported grammar in MindSpore Rewrite, "
|
|
24
|
+
output_estr += reason
|
|
25
|
+
if child_node:
|
|
26
|
+
output_estr += "\n" + "-" * 100 + "\n"
|
|
27
|
+
output_estr += astunparse.unparse(child_node)
|
|
28
|
+
output_estr += "-" * 100 + "\n"
|
|
29
|
+
if father_node:
|
|
30
|
+
output_estr += "\nin\n"
|
|
31
|
+
if father_node:
|
|
32
|
+
output_estr += "-" * 100 + "\n"
|
|
33
|
+
output_estr += astunparse.unparse(father_node)
|
|
34
|
+
output_estr += "-" * 100 + "\n"
|
|
35
|
+
return output_estr
|
mindspore/rewrite/namer.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Unique name producer for target, name of node, class name, etc."""
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
from typing import Union
|
|
18
18
|
|
|
19
19
|
from .node import Node
|
|
@@ -153,7 +153,7 @@ class NodeNamer(Namer):
|
|
|
153
153
|
if isinstance(node_or_name, Node):
|
|
154
154
|
origin_name = node_or_name.get_name()
|
|
155
155
|
if origin_name is None or not origin_name:
|
|
156
|
-
if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive):
|
|
156
|
+
if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.CallFunction):
|
|
157
157
|
if not isinstance(node_or_name, Node):
|
|
158
158
|
raise TypeError("node_or_name should be Node, got: ", type(node_or_name))
|
|
159
159
|
targets = node_or_name.get_targets()
|
mindspore/rewrite/namespace.py
CHANGED
|
@@ -13,22 +13,32 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Define the namespace of MindSpore op definition."""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
from .._extends.parse.namespace import CellNamespace
|
|
18
17
|
|
|
19
18
|
|
|
20
19
|
_ms_common_ns = CellNamespace('mindspore.common')
|
|
21
20
|
_ms_nn_ns = CellNamespace('mindspore.nn')
|
|
22
|
-
_ms_ops_ns = CellNamespace('mindspore.ops')
|
|
21
|
+
_ms_ops_ns = CellNamespace('mindspore.ops.operations')
|
|
22
|
+
_ms_functional_ns = CellNamespace('mindspore.ops.functional')
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def is_subtree(cls_name):
|
|
26
26
|
"""Determine whether 'cls_name' is a subtree."""
|
|
27
|
-
if cls_name == "SequentialCell":
|
|
28
|
-
return True
|
|
29
27
|
if cls_name == "QuantizeWrapperCell":
|
|
30
28
|
return False
|
|
31
29
|
if cls_name in _ms_common_ns or cls_name in _ms_nn_ns or cls_name in _ms_ops_ns:
|
|
32
30
|
return False
|
|
33
31
|
|
|
34
32
|
return True
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def is_functional(func_name):
|
|
36
|
+
"""Determine whether 'cls_name' is a functional."""
|
|
37
|
+
return func_name in _ms_functional_ns
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_functional(func_name):
|
|
41
|
+
"""Get the function corresponding to the func_name."""
|
|
42
|
+
if func_name in _ms_functional_ns:
|
|
43
|
+
return _ms_functional_ns[func_name]
|
|
44
|
+
return None
|
mindspore/rewrite/node.py
CHANGED
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Node class define of Rewrite. See detail in Node class docstring."""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
from typing import Optional, Union
|
|
18
17
|
import ast
|
|
19
18
|
import inspect
|
|
@@ -21,7 +20,7 @@ import inspect
|
|
|
21
20
|
from mindspore.nn import Cell
|
|
22
21
|
from mindspore.ops import Primitive
|
|
23
22
|
from mindspore import log as logger
|
|
24
|
-
from ..
|
|
23
|
+
from .. import _checkparam as Validator
|
|
25
24
|
from .ast_helpers import AstModifier
|
|
26
25
|
from .api.scoped_value import ScopedValue, ValueType
|
|
27
26
|
from .api.node_type import NodeType
|
|
@@ -85,7 +84,9 @@ class Node:
|
|
|
85
84
|
"""
|
|
86
85
|
self._node_type: NodeType = node_type
|
|
87
86
|
self._ast_node: Optional[ast.AST] = ast_node
|
|
88
|
-
self._attribute: {str, object} =
|
|
87
|
+
self._attribute: {str, object} = {}
|
|
88
|
+
if node_type in (NodeType.CallModule, NodeType.CallCell, NodeType.CallPrimitive):
|
|
89
|
+
self._attribute = Node._get_cell_or_prim_op_attribute(instance)
|
|
89
90
|
self._instance = instance
|
|
90
91
|
self._name = name
|
|
91
92
|
self._func: Optional[ScopedValue] = func
|
|
@@ -221,6 +222,32 @@ class Node:
|
|
|
221
222
|
return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), real_return_values, {},
|
|
222
223
|
name, None)
|
|
223
224
|
|
|
225
|
+
@classmethod
|
|
226
|
+
def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue],
|
|
227
|
+
op_type: ScopedValue, args: [ScopedValue],
|
|
228
|
+
ops: {str: list}, name: str = ""):
|
|
229
|
+
"""
|
|
230
|
+
Class method of Node. Instantiate an instance of node whose type is `MathOps` .
|
|
231
|
+
A mathops node is used to represent a node with mathematical operations, such as
|
|
232
|
+
`y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of
|
|
236
|
+
node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and
|
|
237
|
+
ast.Compare.
|
|
238
|
+
targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`.
|
|
239
|
+
See detail in docstring of Node class.
|
|
240
|
+
op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type.
|
|
241
|
+
args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved
|
|
242
|
+
sequentially in the list.
|
|
243
|
+
ops (dict[str:ScopedValue]): Operators participating in the mathematical operations. All operators are
|
|
244
|
+
saved sequentially in the dict, and keys are numbers in string format, such as {'0':'add', '1':'sub'}.
|
|
245
|
+
name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`.
|
|
246
|
+
Name of node also used as field name in network class. The format of mathops node name
|
|
247
|
+
is 'AstNodeName_AstOpName_n'.
|
|
248
|
+
"""
|
|
249
|
+
return cls(NodeType.MathOps, ast_node, targets, op_type, args, ops, name, None)
|
|
250
|
+
|
|
224
251
|
@staticmethod
|
|
225
252
|
def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]],
|
|
226
253
|
func: Union[ScopedValue, str], args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None,
|
|
@@ -622,7 +649,9 @@ class Node:
|
|
|
622
649
|
targets ([ScopedValue]): A list of instances of ScopedValue as new targets.
|
|
623
650
|
"""
|
|
624
651
|
self._targets = targets
|
|
625
|
-
if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
|
|
652
|
+
if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive,
|
|
653
|
+
NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer,
|
|
654
|
+
NodeType.MathOps):
|
|
626
655
|
self._sync_assign_targets_to_ast()
|
|
627
656
|
|
|
628
657
|
def get_func(self) -> ScopedValue:
|
|
@@ -719,12 +748,12 @@ class Node:
|
|
|
719
748
|
ValueError: If `node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.
|
|
720
749
|
"""
|
|
721
750
|
Validator.check_value_type("node", node, [Node], "Node")
|
|
722
|
-
Validator.check_int_range(arg_idx, 0, self._args_num,
|
|
751
|
+
Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx")
|
|
723
752
|
if out_idx is None:
|
|
724
753
|
if len(node._targets) != 1:
|
|
725
754
|
raise RuntimeError("node should has one output when out_idx is not provided")
|
|
726
755
|
out_idx = 0
|
|
727
|
-
Validator.check_int_range(out_idx, 0, len(node._targets),
|
|
756
|
+
Validator.check_int_range(out_idx, 0, len(node._targets), Validator.INC_LEFT, "arg_idx")
|
|
728
757
|
new_arg = node._targets[out_idx]
|
|
729
758
|
self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg
|
|
730
759
|
self._sync_arg()
|
|
@@ -741,7 +770,7 @@ class Node:
|
|
|
741
770
|
Raises:
|
|
742
771
|
ValueError: If `index` is out of range.
|
|
743
772
|
"""
|
|
744
|
-
Validator.check_int_range(index, 0, self._args_num,
|
|
773
|
+
Validator.check_int_range(index, 0, self._args_num, Validator.INC_LEFT, "index")
|
|
745
774
|
Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
|
|
746
775
|
if isinstance(arg, str):
|
|
747
776
|
arg = ScopedValue.create_naming_value(arg)
|
|
@@ -761,7 +790,7 @@ class Node:
|
|
|
761
790
|
Raises:
|
|
762
791
|
TypeError: Element of new argument is not an instance of ScopedValue.
|
|
763
792
|
"""
|
|
764
|
-
Validator.check_int_range(len(args), 0, self._args_num,
|
|
793
|
+
Validator.check_int_range(len(args), 0, self._args_num, Validator.INC_LEFT, "Length of args")
|
|
765
794
|
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
|
|
766
795
|
for arg_index, arg in enumerate(args):
|
|
767
796
|
if not isinstance(arg, ScopedValue):
|
|
@@ -781,7 +810,7 @@ class Node:
|
|
|
781
810
|
TypeError: Value of new argument is not an instance of ScopedValue.
|
|
782
811
|
RuntimeError: Length of new arguments is not equal to length of old arguments.
|
|
783
812
|
"""
|
|
784
|
-
Validator.check_int_range(len(kwargs), 0, self._kwargs_num,
|
|
813
|
+
Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Validator.INC_LEFT, "Length of kwargs")
|
|
785
814
|
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
|
|
786
815
|
for key, arg in kwargs.items():
|
|
787
816
|
if key not in self._normalized_args.keys() or key not in self._normalized_args_keys:
|
|
@@ -1097,7 +1126,7 @@ class Node:
|
|
|
1097
1126
|
elt.id = scoped_value.value
|
|
1098
1127
|
elif isinstance(elt, ast.Attribute) and isinstance(elt.value, ast.Name):
|
|
1099
1128
|
elt.value.id = scoped_value.scope
|
|
1100
|
-
elt.
|
|
1129
|
+
elt.attr = scoped_value.value
|
|
1101
1130
|
else:
|
|
1102
1131
|
raise RuntimeError("Only support constant or symbol in tuple now")
|
|
1103
1132
|
else:
|
|
@@ -1131,14 +1160,50 @@ class Node:
|
|
|
1131
1160
|
raise RuntimeError("Unsupported return value type: ", return_value_ast)
|
|
1132
1161
|
ast.fix_missing_locations(return_ast)
|
|
1133
1162
|
|
|
1163
|
+
def _sync_mathops_node_args_to_ast(self):
|
|
1164
|
+
"""
|
|
1165
|
+
Sync values from self._normalized_args to the ast node for mathematical operations.
|
|
1166
|
+
"""
|
|
1167
|
+
if self._ast_node is None:
|
|
1168
|
+
return
|
|
1169
|
+
if not isinstance(self._ast_node, ast.Assign):
|
|
1170
|
+
raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}")
|
|
1171
|
+
mathops_node = self._ast_node.value
|
|
1172
|
+
if isinstance(mathops_node, ast.BinOp):
|
|
1173
|
+
left = mathops_node.left
|
|
1174
|
+
right = mathops_node.right
|
|
1175
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
|
|
1176
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[1]), right)
|
|
1177
|
+
elif isinstance(mathops_node, ast.UnaryOp):
|
|
1178
|
+
operand = mathops_node.operand
|
|
1179
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), operand)
|
|
1180
|
+
elif isinstance(mathops_node, ast.BoolOp):
|
|
1181
|
+
values = mathops_node.values
|
|
1182
|
+
for arg_index in range(self._args_num):
|
|
1183
|
+
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
|
1184
|
+
AstModifier.update_arg_value(arg_value, values[arg_index])
|
|
1185
|
+
elif isinstance(mathops_node, ast.Compare):
|
|
1186
|
+
left = mathops_node.left
|
|
1187
|
+
AstModifier.update_arg_value(self._normalized_args.get(self._normalized_args_keys[0]), left)
|
|
1188
|
+
comparators = mathops_node.comparators
|
|
1189
|
+
for arg_index in range(1, self._args_num):
|
|
1190
|
+
arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index])
|
|
1191
|
+
AstModifier.update_arg_value(arg_value, comparators[arg_index - 1])
|
|
1192
|
+
else:
|
|
1193
|
+
raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, "
|
|
1194
|
+
"ast.BoolOp, ast.Compare), but got ", type(mathops_node))
|
|
1195
|
+
|
|
1134
1196
|
def _sync_arg(self):
|
|
1135
1197
|
"""Sync _normalized_args to corresponding ast node when updated."""
|
|
1136
|
-
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree
|
|
1198
|
+
if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree,\
|
|
1199
|
+
NodeType.CellContainer, NodeType.CallFunction):
|
|
1137
1200
|
self._sync_call_cell_args_to_ast()
|
|
1138
1201
|
elif self._node_type == NodeType.Output:
|
|
1139
1202
|
self._sync_return_node_to_ast()
|
|
1140
1203
|
elif self._node_type == NodeType.CallMethod:
|
|
1141
1204
|
self._sync_call_method_args_to_ast()
|
|
1205
|
+
elif self._node_type == NodeType.MathOps:
|
|
1206
|
+
self._sync_mathops_node_args_to_ast()
|
|
1142
1207
|
|
|
1143
1208
|
|
|
1144
1209
|
class TreeNode(Node):
|
|
@@ -1186,8 +1251,6 @@ class TreeNode(Node):
|
|
|
1186
1251
|
instance: Object in network corresponding to this node.
|
|
1187
1252
|
"""
|
|
1188
1253
|
|
|
1189
|
-
if not isinstance(instance, Cell):
|
|
1190
|
-
raise ValueError("Argument instance should be a Cell: ", type(instance))
|
|
1191
1254
|
non_custom_args = Node._handle_custom_obj_in_args(args)
|
|
1192
1255
|
non_custom_kwargs = Node._handle_custom_obj_in_kwargs(kwargs)
|
|
1193
1256
|
new_targets = Node._handle_targets(targets)
|
|
@@ -1196,3 +1259,88 @@ class TreeNode(Node):
|
|
|
1196
1259
|
if ast_node is None:
|
|
1197
1260
|
ast_node = AstModifier.create_call_assign(new_targets, func, non_custom_args, non_custom_kwargs)
|
|
1198
1261
|
return cls(tree, ast_node, new_targets, func, args, kwargs, name, instance)
|
|
1262
|
+
|
|
1263
|
+
|
|
1264
|
+
class CellContainer(Node):
|
|
1265
|
+
""" Container for saving cell-objects node. """
|
|
1266
|
+
class _Visitor():
|
|
1267
|
+
""" A iterator of CellContainer nodes. """
|
|
1268
|
+
def __init__(self, cellcontainer):
|
|
1269
|
+
self._cellcontainer = cellcontainer
|
|
1270
|
+
|
|
1271
|
+
def __len__(self):
|
|
1272
|
+
""" Get the number of nodes. """
|
|
1273
|
+
return self._cellcontainer.node_count
|
|
1274
|
+
|
|
1275
|
+
def __iter__(self):
|
|
1276
|
+
"""Create an iterator over the CellContainer."""
|
|
1277
|
+
count = len(self._cellcontainer.node_list)
|
|
1278
|
+
i = 0
|
|
1279
|
+
while i < count:
|
|
1280
|
+
curr = self._cellcontainer.node_list[i]
|
|
1281
|
+
if curr.valid:
|
|
1282
|
+
yield curr
|
|
1283
|
+
i += 1
|
|
1284
|
+
|
|
1285
|
+
def __init__(self, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue,
|
|
1286
|
+
args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance):
|
|
1287
|
+
"""Constructor of CellContainer.
|
|
1288
|
+
|
|
1289
|
+
Args:
|
|
1290
|
+
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
1291
|
+
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1292
|
+
func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
1293
|
+
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1294
|
+
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
1295
|
+
name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree.
|
|
1296
|
+
Name of node also used as field name in network class.
|
|
1297
|
+
instance: Object in network corresponding to this node.
|
|
1298
|
+
"""
|
|
1299
|
+
if isinstance(func, str):
|
|
1300
|
+
func = ScopedValue.create_naming_value(func)
|
|
1301
|
+
super().__init__(NodeType.CellContainer, ast_node, targets, func, args, kwargs, name, instance)
|
|
1302
|
+
self._node_list = list()
|
|
1303
|
+
self._node_count = 0
|
|
1304
|
+
|
|
1305
|
+
@property
|
|
1306
|
+
def node_count(self):
|
|
1307
|
+
"""Number of nodes."""
|
|
1308
|
+
return len(self._node_list)
|
|
1309
|
+
|
|
1310
|
+
@property
|
|
1311
|
+
def node_list(self):
|
|
1312
|
+
""" Get node list. """
|
|
1313
|
+
return self._node_list
|
|
1314
|
+
|
|
1315
|
+
def append(self, node):
|
|
1316
|
+
""" Append new node to node list. """
|
|
1317
|
+
setattr(node, "container", self)
|
|
1318
|
+
setattr(node, "valid", True)
|
|
1319
|
+
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1320
|
+
self._node_list.append(node)
|
|
1321
|
+
# when creating a cell_container, node instance is already in SequentialCell cell_list
|
|
1322
|
+
# so here we need to write a if judgement
|
|
1323
|
+
if node.get_instance() not in self.get_instance().cell_list:
|
|
1324
|
+
self.get_instance().append(node.get_instance())
|
|
1325
|
+
|
|
1326
|
+
def erase(self, node):
|
|
1327
|
+
"""Erase node form container."""
|
|
1328
|
+
index_node = self.node_list.index(node)
|
|
1329
|
+
index_instance = self.get_instance().cell_list.index(node.get_instance())
|
|
1330
|
+
if index_node != index_instance:
|
|
1331
|
+
raise RuntimeError("In MindSpore Rewrite CellContainer, erasing a node raises index error!!!")
|
|
1332
|
+
setattr(node, "valid", False)
|
|
1333
|
+
del self.get_instance()[index_node]
|
|
1334
|
+
del self._node_list[index_node]
|
|
1335
|
+
|
|
1336
|
+
def insert(self, index, node):
|
|
1337
|
+
"""Insert node into container"""
|
|
1338
|
+
self.node_list.insert(index, node)
|
|
1339
|
+
setattr(node, "container", self)
|
|
1340
|
+
setattr(node, "valid", True)
|
|
1341
|
+
node.set_belong_symbol_tree(self.get_belong_symbol_tree())
|
|
1342
|
+
self.get_instance()._insert(index, node.get_instance())
|
|
1343
|
+
|
|
1344
|
+
def nodes(self):
|
|
1345
|
+
""" Return a iterator of node."""
|
|
1346
|
+
return self._Visitor(self)
|
mindspore/rewrite/parser.py
CHANGED
|
@@ -13,12 +13,12 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Parse ast.arguments to input-node of SymbolTree."""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
import ast
|
|
18
17
|
|
|
19
18
|
from ..parser import Parser
|
|
20
19
|
from ..parser_register import reg_parser
|
|
21
20
|
from ..symbol_tree import SymbolTree
|
|
21
|
+
from ..common import error_str
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class ArgumentsParser(Parser):
|
|
@@ -44,7 +44,8 @@ class ArgumentsParser(Parser):
|
|
|
44
44
|
|
|
45
45
|
for arg in node.args:
|
|
46
46
|
if not isinstance(arg, ast.arg):
|
|
47
|
-
raise RuntimeError("
|
|
47
|
+
raise RuntimeError(error_str(f"only support ast.arg in arguments arg, but got '{type(arg).__name__}'",
|
|
48
|
+
child_node=arg, father_node=node))
|
|
48
49
|
stree.append_input_node(arg, arg.arg)
|
|
49
50
|
|
|
50
51
|
if hasattr(node, "vararg"):
|