mindspore 1.10.0__cp37-cp37m-win_amd64.whl → 2.0.0rc1__cp37-cp37m-win_amd64.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +9 -4
- mindspore/_c_dataengine.cp37-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp37-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp37-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +102 -0
- mindspore/_checkparam.py +1066 -1001
- mindspore/_extends/builtin_operations.py +32 -4
- mindspore/_extends/graph_kernel/model/graph_split.py +66 -222
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +12 -9
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +119 -26
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +50 -50
- mindspore/_extends/parallel_compile/akg_compiler/util.py +9 -6
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +4 -25
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +9 -4
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -27
- mindspore/_extends/parse/__init__.py +5 -3
- mindspore/_extends/parse/namespace.py +17 -2
- mindspore/_extends/parse/parser.py +193 -34
- mindspore/_extends/parse/resources.py +7 -8
- mindspore/_extends/parse/standard_method.py +1780 -435
- mindspore/_extends/parse/trope.py +3 -1
- mindspore/amp.py +53 -58
- mindspore/atlprov.dll +0 -0
- mindspore/boost/adasum.py +3 -2
- mindspore/boost/boost.py +2 -2
- mindspore/boost/boost_cell_wrapper.py +46 -26
- mindspore/boost/dim_reduce.py +6 -5
- mindspore/boost/grad_accumulation.py +2 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +11 -10
- mindspore/common/_decorator.py +2 -0
- mindspore/common/_register_for_adapter.py +55 -0
- mindspore/common/_stub_tensor.py +201 -0
- mindspore/common/_utils.py +57 -0
- mindspore/common/api.py +582 -297
- mindspore/common/dtype.py +66 -18
- mindspore/common/dump.py +2 -2
- mindspore/common/initializer.py +38 -1
- mindspore/common/jit_config.py +25 -13
- mindspore/common/mutable.py +53 -24
- mindspore/common/parameter.py +60 -37
- mindspore/common/seed.py +8 -24
- mindspore/common/sparse_tensor.py +927 -0
- mindspore/common/tensor.py +1627 -3900
- mindspore/communication/__init__.py +10 -5
- mindspore/communication/_comm_helper.py +78 -214
- mindspore/communication/_hccl_management.py +2 -1
- mindspore/communication/management.py +136 -47
- mindspore/config/op_info.config +501 -1008
- mindspore/context.py +291 -56
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +12 -8
- mindspore/dataset/audio/__init__.py +9 -9
- mindspore/dataset/audio/transforms.py +1090 -228
- mindspore/dataset/audio/utils.py +87 -39
- mindspore/dataset/audio/validators.py +223 -1
- mindspore/dataset/callback/ds_callback.py +17 -15
- mindspore/dataset/core/config.py +246 -17
- mindspore/dataset/core/py_util_helpers.py +4 -3
- mindspore/dataset/core/validator_helpers.py +10 -10
- mindspore/{parallel/nn/layers.py → dataset/debug/__init__.py} +7 -8
- mindspore/dataset/debug/debug_hook.py +65 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +7 -3
- mindspore/dataset/engine/cache_client.py +9 -9
- mindspore/dataset/engine/datasets.py +648 -477
- mindspore/dataset/engine/datasets_audio.py +165 -167
- mindspore/dataset/engine/datasets_standard_format.py +93 -67
- mindspore/dataset/engine/datasets_text.py +492 -342
- mindspore/dataset/engine/datasets_user_defined.py +85 -50
- mindspore/dataset/engine/datasets_vision.py +1224 -699
- mindspore/dataset/engine/graphdata.py +134 -69
- mindspore/dataset/engine/iterators.py +50 -9
- mindspore/dataset/engine/offload.py +52 -31
- mindspore/dataset/engine/samplers.py +27 -24
- mindspore/dataset/engine/serializer_deserializer.py +14 -15
- mindspore/dataset/engine/validators.py +213 -52
- mindspore/dataset/text/__init__.py +10 -8
- mindspore/dataset/text/transforms.py +152 -57
- mindspore/dataset/text/utils.py +98 -49
- mindspore/dataset/text/validators.py +25 -0
- mindspore/dataset/transforms/__init__.py +4 -2
- mindspore/dataset/transforms/c_transforms.py +11 -13
- mindspore/dataset/transforms/py_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms_util.py +10 -0
- mindspore/dataset/transforms/transforms.py +13 -15
- mindspore/dataset/transforms/validators.py +7 -7
- mindspore/dataset/utils/__init__.py +2 -1
- mindspore/dataset/utils/browse_dataset.py +13 -13
- mindspore/dataset/utils/line_reader.py +121 -0
- mindspore/dataset/vision/__init__.py +8 -7
- mindspore/dataset/vision/c_transforms.py +125 -126
- mindspore/dataset/vision/py_transforms.py +37 -37
- mindspore/dataset/vision/py_transforms_util.py +23 -20
- mindspore/dataset/vision/transforms.py +316 -315
- mindspore/dataset/vision/utils.py +313 -17
- mindspore/dataset/vision/validators.py +6 -6
- mindspore/default_config.py +0 -1
- mindspore/dpcmi.dll +0 -0
- mindspore/{compression → experimental}/__init__.py +6 -5
- mindspore/experimental/map_parameter.py +275 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/callback/callback.h +9 -13
- mindspore/include/api/callback/ckpt_saver.h +2 -2
- mindspore/include/api/callback/loss_monitor.h +2 -2
- mindspore/include/api/callback/lr_scheduler.h +5 -5
- mindspore/include/api/callback/time_monitor.h +2 -2
- mindspore/include/api/callback/train_accuracy.h +4 -6
- mindspore/include/api/cfg.h +19 -6
- mindspore/include/api/context.h +70 -9
- mindspore/include/api/delegate.h +8 -1
- mindspore/include/api/dual_abi_helper.h +8 -24
- mindspore/include/api/metrics/accuracy.h +2 -2
- mindspore/include/api/metrics/metrics.h +4 -3
- mindspore/include/api/model.h +9 -4
- mindspore/include/api/model_group.h +68 -0
- mindspore/include/api/model_parallel_runner.h +17 -17
- mindspore/include/api/net.h +12 -11
- mindspore/include/api/serialization.h +20 -4
- mindspore/include/api/status.h +7 -1
- mindspore/include/api/types.h +25 -21
- mindspore/include/api/visible.h +4 -0
- mindspore/include/c_api/model_c.h +5 -0
- mindspore/include/c_api/status_c.h +1 -1
- mindspore/include/dataset/config.h +1 -1
- mindspore/include/dataset/constants.h +14 -0
- mindspore/include/dataset/text.h +59 -0
- mindspore/include/dataset/vision.h +56 -117
- mindspore/include/dataset/vision_lite.h +102 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +28 -28
- mindspore/mindrecord/common/exceptions.py +2 -4
- mindspore/mindrecord/filereader.py +19 -1
- mindspore/mindrecord/filewriter.py +250 -88
- mindspore/mindrecord/mindpage.py +13 -13
- mindspore/mindrecord/shardheader.py +15 -15
- mindspore/mindrecord/shardreader.py +9 -0
- mindspore/mindrecord/shardwriter.py +29 -29
- mindspore/mindrecord/tools/cifar100_to_mr.py +9 -9
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -9
- mindspore/mindrecord/tools/csv_to_mr.py +4 -4
- mindspore/mindrecord/tools/imagenet_to_mr.py +70 -65
- mindspore/mindrecord/tools/mnist_to_mr.py +41 -41
- mindspore/mindrecord/tools/tfrecord_to_mr.py +6 -6
- mindspore/{libmindspore_backend.dll → mindspore_backend.dll} +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +1 -5
- mindspore/nn/cell.py +297 -234
- mindspore/nn/dynamic_lr.py +1 -1
- mindspore/nn/grad/cell_grad.py +17 -42
- mindspore/nn/layer/__init__.py +7 -4
- mindspore/nn/layer/activation.py +131 -88
- mindspore/nn/layer/basic.py +313 -613
- mindspore/nn/layer/channel_shuffle.py +103 -0
- mindspore/nn/layer/combined.py +1 -1
- mindspore/nn/layer/container.py +52 -6
- mindspore/nn/layer/conv.py +112 -43
- mindspore/nn/layer/dense.py +10 -9
- mindspore/nn/layer/embedding.py +36 -34
- mindspore/nn/layer/image.py +123 -27
- mindspore/nn/layer/math.py +108 -107
- mindspore/nn/layer/normalization.py +212 -366
- mindspore/nn/layer/padding.py +370 -42
- mindspore/nn/layer/pooling.py +1443 -219
- mindspore/nn/layer/rnn_cells.py +11 -16
- mindspore/nn/layer/rnns.py +38 -39
- mindspore/nn/layer/thor_layer.py +24 -25
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +701 -0
- mindspore/nn/learning_rate_schedule.py +8 -8
- mindspore/nn/loss/__init__.py +9 -6
- mindspore/nn/loss/loss.py +678 -142
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +2 -2
- mindspore/nn/optim/ada_grad.py +8 -8
- mindspore/nn/optim/adadelta.py +2 -3
- mindspore/nn/optim/adafactor.py +18 -14
- mindspore/nn/optim/adam.py +429 -87
- mindspore/nn/optim/adamax.py +5 -6
- mindspore/nn/optim/adasum.py +10 -8
- mindspore/nn/optim/asgd.py +7 -7
- mindspore/nn/optim/ftrl.py +81 -11
- mindspore/nn/optim/lamb.py +7 -8
- mindspore/nn/optim/lars.py +4 -4
- mindspore/nn/optim/lazyadam.py +82 -7
- mindspore/nn/optim/momentum.py +8 -7
- mindspore/nn/optim/optimizer.py +19 -10
- mindspore/nn/optim/proximal_ada_grad.py +6 -5
- mindspore/nn/optim/rmsprop.py +3 -3
- mindspore/nn/optim/rprop.py +20 -16
- mindspore/nn/optim/sgd.py +21 -15
- mindspore/nn/optim/thor.py +23 -21
- mindspore/nn/probability/__init__.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -6
- mindspore/nn/probability/bijector/invert.py +4 -2
- mindspore/nn/probability/bijector/softplus.py +2 -2
- mindspore/nn/probability/bnn_layers/dense_variational.py +1 -1
- mindspore/nn/probability/bnn_layers/layer_distribution.py +2 -2
- mindspore/nn/probability/distribution/__init__.py +6 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +11 -17
- mindspore/nn/probability/distribution/bernoulli.py +6 -6
- mindspore/nn/probability/distribution/beta.py +1 -1
- mindspore/nn/probability/distribution/categorical.py +9 -9
- mindspore/nn/probability/distribution/cauchy.py +8 -8
- mindspore/nn/probability/distribution/distribution.py +12 -6
- mindspore/nn/probability/distribution/exponential.py +5 -5
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +6 -5
- mindspore/nn/probability/distribution/gumbel.py +5 -5
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +0 -1
- mindspore/nn/probability/distribution/logistic.py +4 -5
- mindspore/nn/probability/distribution/normal.py +11 -15
- mindspore/nn/probability/distribution/poisson.py +6 -2
- mindspore/nn/probability/distribution/student_t.py +150 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +4 -4
- mindspore/nn/probability/distribution/uniform.py +5 -5
- mindspore/nn/reinforcement/_tensors_queue.py +3 -3
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +8 -1
- mindspore/nn/wrap/cell_wrapper.py +55 -27
- mindspore/nn/wrap/grad_reducer.py +20 -11
- mindspore/nn/wrap/loss_scale.py +47 -30
- mindspore/numpy/array_creations.py +33 -22
- mindspore/numpy/array_ops.py +46 -42
- mindspore/numpy/logic_ops.py +6 -27
- mindspore/numpy/math_ops.py +26 -19
- mindspore/numpy/utils.py +1 -8
- mindspore/numpy/utils_const.py +112 -62
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -3
- mindspore/ops/_constants.py +0 -6
- mindspore/ops/_grad/__init__.py +2 -1
- mindspore/ops/_grad/grad_array_ops.py +209 -152
- mindspore/ops/_grad/grad_base.py +55 -17
- mindspore/ops/_grad/grad_clip_ops.py +11 -3
- mindspore/ops/_grad/grad_comm_ops.py +58 -47
- mindspore/ops/_grad/grad_implementations.py +21 -61
- mindspore/ops/_grad/grad_inner_ops.py +48 -6
- mindspore/ops/_grad/grad_math_ops.py +306 -161
- mindspore/ops/_grad/grad_nn_ops.py +192 -181
- mindspore/ops/_grad/grad_other_ops.py +1 -1
- mindspore/ops/_grad/grad_quant_ops.py +5 -5
- mindspore/ops/_grad/grad_sequence_ops.py +296 -0
- mindspore/ops/_grad/grad_sparse.py +15 -9
- mindspore/ops/_grad_experimental/__init__.py +1 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +441 -55
- mindspore/ops/_grad_experimental/grad_image_ops.py +25 -7
- mindspore/ops/_grad_experimental/grad_inner_ops.py +3 -44
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +16 -21
- mindspore/ops/_grad_experimental/grad_math_ops.py +979 -49
- mindspore/ops/_grad_experimental/grad_nn_ops.py +78 -8
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +112 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +197 -13
- mindspore/ops/_op_impl/__init__.py +3 -3
- mindspore/ops/_op_impl/_custom_op/__init__.py +0 -1
- mindspore/ops/_op_impl/_custom_op/_basic.py +0 -1
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +4 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +5 -5
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +3 -3
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +3 -3
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +4 -8
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +2 -2
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +2 -2
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +0 -1
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +2 -2
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +2 -2
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +238 -3
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +0 -32
- mindspore/ops/_op_impl/aicpu/addcmul.py +0 -84
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/{nn/probability/infer/variational/__init__.py → ops/_op_impl/aicpu/cauchy.py} +17 -10
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +1 -1
- mindspore/ops/_op_impl/{cpu/bias_add.py → aicpu/choleskygrad.py} +9 -7
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +11 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +2 -2
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/{cpu/bias_add_grad.py → aicpu/digamma.py} +9 -7
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +41 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +1 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +1 -1
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/{tbe/scatter_add_ds.py → aicpu/inplace_index_add.py} +17 -21
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +39 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +32 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +2 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +0 -4
- mindspore/ops/_op_impl/aicpu/mul.py +3 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +14 -6
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +0 -2
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +3 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +2 -8
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +1 -1
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +4 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +2 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +12 -6
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +0 -24
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/{tbe/slice_ds.py → aicpu/sparse_segment_sum.py} +16 -24
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +2 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +93 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +66 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/{tbe/gather_v2.py → aicpu/tile.py} +24 -24
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +34 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +1 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +10 -2
- mindspore/ops/_op_impl/cpu/__init__.py +1 -2
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +5 -1
- mindspore/ops/_op_impl/cpu/maximum_grad.py +2 -0
- mindspore/{compression/common/__init__.py → ops/_op_impl/cpu/pyexecute.py} +13 -8
- mindspore/ops/_op_impl/cpu/reduce_sum.py +8 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/tensor_shape.py +5 -1
- mindspore/ops/_op_impl/tbe/__init__.py +27 -608
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +42 -0
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +1 -1
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +41 -0
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +1 -0
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +2 -0
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +4 -2
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +40 -0
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -1
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -1
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +6 -4
- mindspore/ops/_op_impl/tbe/cast.py +0 -2
- mindspore/ops/_op_impl/tbe/cast_ds.py +3 -3
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -2
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -2
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +1 -0
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +1 -0
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +1 -1
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_nd.py +1 -0
- mindspore/ops/_op_impl/tbe/greater.py +2 -0
- mindspore/ops/_op_impl/tbe/{index_add.py → inplace_index_add.py} +3 -6
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -1
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +35 -0
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -6
- mindspore/ops/_op_impl/tbe/{greater_ds.py → reduce_all_ds.py} +13 -16
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +39 -0
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +44 -0
- mindspore/ops/_op_impl/tbe/scatter_add.py +2 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +2 -2
- mindspore/ops/_op_impl/tbe/slice.py +26 -15
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +1 -1
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +1 -0
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +15 -5
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +1 -1
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +3 -2
- mindspore/ops/_register_for_op.py +11 -0
- mindspore/ops/_utils/__init__.py +1 -1
- mindspore/ops/_utils/utils.py +20 -41
- mindspore/ops/_vmap/__init__.py +2 -2
- mindspore/ops/_vmap/vmap_array_ops.py +170 -78
- mindspore/ops/_vmap/vmap_base.py +24 -10
- mindspore/ops/_vmap/vmap_convolution_ops.py +7 -10
- mindspore/ops/_vmap/vmap_grad_math_ops.py +4 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -9
- mindspore/ops/_vmap/vmap_image_ops.py +52 -0
- mindspore/ops/_vmap/vmap_math_ops.py +77 -6
- mindspore/ops/_vmap/vmap_nn_ops.py +78 -29
- mindspore/ops/_vmap/vmap_other_ops.py +3 -1
- mindspore/ops/_vmap/vmap_random_ops.py +55 -3
- mindspore/ops/_vmap/vmap_sparse_ops.py +1 -0
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +16 -16
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +150 -0
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +306 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +12 -8
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +240 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +247 -0
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +315 -0
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +278 -0
- mindspore/ops/bprop_mindir/DType_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +138 -0
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +22 -24
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +16 -14
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +58 -0
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +57 -0
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +13 -10
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +14 -11
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +126 -0
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +15 -0
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +30 -0
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +43 -0
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +22 -19
- mindspore/ops/bprop_mindir/Load_bprop.mindir +12 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +17 -18
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +74 -0
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +75 -0
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +65 -0
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +27 -0
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +35 -0
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +18 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +25 -23
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +13 -13
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +82 -0
- mindspore/ops/bprop_mindir/Range_bprop.mindir +21 -19
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +11 -11
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +18 -17
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +19 -23
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +60 -0
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +89 -0
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +52 -0
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Round_bprop.mindir +14 -13
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +24 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/Select_bprop.mindir +30 -34
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +12 -12
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +21 -0
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +26 -0
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +36 -0
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +33 -0
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +28 -0
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +54 -0
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +95 -0
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +98 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +28 -32
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +66 -0
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +29 -0
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +14 -0
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +23 -0
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +18 -15
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +11 -13
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +16 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +22 -0
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +32 -0
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +38 -0
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +13 -12
- mindspore/ops/bprop_mindir/__init__.py +1 -4
- mindspore/ops/bprop_mindir/generate_mindir.py +32 -20
- mindspore/ops/composite/__init__.py +12 -13
- mindspore/ops/composite/base.py +261 -254
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +197 -156
- mindspore/ops/composite/multitype_ops/_compile_utils.py +428 -176
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +188 -87
- mindspore/ops/composite/multitype_ops/add_impl.py +23 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/equal_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +52 -5
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +31 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +33 -2
- mindspore/ops/composite/multitype_ops/less_impl.py +33 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +21 -7
- mindspore/ops/composite/multitype_ops/not_in_impl.py +15 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -4
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +62 -70
- mindspore/ops/composite/multitype_ops/sub_impl.py +3 -3
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +41 -4
- mindspore/ops/function/__init__.py +323 -8
- mindspore/ops/function/array_func.py +3511 -780
- mindspore/ops/function/clip_func.py +329 -0
- mindspore/ops/function/debug_func.py +6 -6
- mindspore/ops/function/grad/__init__.py +5 -1
- mindspore/ops/function/grad/grad_func.py +736 -65
- mindspore/ops/function/image_func.py +270 -0
- mindspore/ops/function/linalg_func.py +268 -8
- mindspore/ops/function/math_func.py +8032 -3164
- mindspore/ops/function/nn_func.py +5619 -1855
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +11 -10
- mindspore/ops/function/random_func.py +939 -77
- mindspore/ops/function/sparse_func.py +249 -84
- mindspore/ops/function/sparse_unary_func.py +2303 -0
- mindspore/ops/function/spectral_func.py +146 -0
- mindspore/ops/function/vmap_func.py +114 -0
- mindspore/ops/functional.py +182 -254
- mindspore/ops/op_info_register.py +79 -34
- mindspore/ops/operations/__init__.py +210 -118
- mindspore/ops/operations/_csr_ops.py +7 -7
- mindspore/ops/operations/_embedding_cache_ops.py +25 -15
- mindspore/ops/operations/_grad_ops.py +447 -322
- mindspore/ops/operations/_inner_ops.py +547 -176
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +29 -27
- mindspore/ops/operations/_ocr_ops.py +11 -11
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_quant_ops.py +186 -101
- mindspore/ops/operations/_rl_inner_ops.py +122 -61
- mindspore/ops/operations/_scalar_ops.py +466 -0
- mindspore/ops/operations/_sequence_ops.py +1047 -0
- mindspore/ops/operations/_tensor_array.py +10 -11
- mindspore/ops/operations/_thor_ops.py +4 -4
- mindspore/ops/operations/array_ops.py +1428 -1226
- mindspore/ops/operations/comm_ops.py +180 -117
- mindspore/ops/operations/control_ops.py +4 -2
- mindspore/ops/operations/custom_ops.py +185 -98
- mindspore/ops/operations/debug_ops.py +92 -54
- mindspore/ops/operations/image_ops.py +406 -211
- mindspore/ops/operations/inner_ops.py +42 -53
- mindspore/ops/operations/linalg_ops.py +32 -29
- mindspore/ops/operations/math_ops.py +2076 -897
- mindspore/ops/operations/nn_ops.py +1282 -1252
- mindspore/ops/operations/other_ops.py +124 -278
- mindspore/ops/operations/random_ops.py +345 -178
- mindspore/ops/operations/rl_ops.py +8 -9
- mindspore/ops/operations/sparse_ops.py +502 -157
- mindspore/ops/operations/spectral_ops.py +107 -0
- mindspore/ops/primitive.py +192 -15
- mindspore/ops/vm_impl_registry.py +23 -2
- mindspore/parallel/__init__.py +6 -1
- mindspore/parallel/_auto_parallel_context.py +199 -92
- mindspore/parallel/_cell_wrapper.py +4 -2
- mindspore/parallel/_cost_model_context.py +3 -0
- mindspore/parallel/_dp_allreduce_fusion.py +2 -1
- mindspore/parallel/_offload_context.py +185 -0
- mindspore/parallel/_parallel_serialization.py +167 -28
- mindspore/parallel/_ps_context.py +9 -5
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +9 -1
- mindspore/{nn/transformer → parallel/_transformer}/__init__.py +6 -6
- mindspore/{nn/transformer → parallel/_transformer}/layers.py +59 -37
- mindspore/{nn/transformer → parallel/_transformer}/loss.py +4 -7
- mindspore/{nn/transformer → parallel/_transformer}/moe.py +160 -35
- mindspore/{nn/transformer → parallel/_transformer}/op_parallel_config.py +3 -3
- mindspore/{nn/transformer → parallel/_transformer}/transformer.py +235 -196
- mindspore/parallel/_utils.py +47 -7
- mindspore/parallel/algo_parameter_config.py +5 -1
- mindspore/parallel/checkpoint_transform.py +329 -0
- mindspore/parallel/shard.py +229 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/util.py +4 -3
- mindspore/profiler/common/validator/validate_path.py +2 -2
- mindspore/profiler/envprofiling.py +249 -0
- mindspore/profiler/parser/aicpu_data_parser.py +38 -39
- mindspore/profiler/parser/ascend_timeline_generator.py +497 -0
- mindspore/profiler/parser/base_timeline_generator.py +471 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +684 -0
- mindspore/profiler/parser/framework_parser.py +42 -16
- mindspore/profiler/parser/hccl_parser.py +158 -158
- mindspore/profiler/parser/hwts_log_parser.py +7 -6
- mindspore/profiler/parser/integrator.py +18 -1579
- mindspore/profiler/parser/minddata_analyzer.py +8 -8
- mindspore/profiler/parser/msadvisor_analyzer.py +14 -27
- mindspore/profiler/parser/msadvisor_parser.py +2 -4
- mindspore/profiler/parser/optime_parser.py +17 -18
- mindspore/profiler/parser/profiler_info.py +108 -0
- mindspore/profiler/parser/step_trace_parser.py +1 -1
- mindspore/profiler/profiling.py +396 -194
- mindspore/rewrite/__init__.py +6 -2
- mindspore/rewrite/api/node.py +51 -110
- mindspore/rewrite/api/node_type.py +10 -6
- mindspore/rewrite/api/pattern_engine.py +51 -7
- mindspore/rewrite/api/scoped_value.py +64 -53
- mindspore/rewrite/api/symbol_tree.py +108 -61
- mindspore/rewrite/api/tree_node_helper.py +2 -3
- mindspore/{compression/quant/__init__.py → rewrite/ast_creator_register.py} +20 -11
- mindspore/rewrite/ast_helpers/__init__.py +6 -3
- mindspore/rewrite/ast_helpers/ast_creator.py +115 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +99 -1
- mindspore/rewrite/ast_helpers/ast_modifier.py +17 -4
- mindspore/rewrite/ast_helpers/ast_replacer.py +1 -1
- mindspore/rewrite/ast_transformers/__init__.py +0 -1
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +46 -5
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +6 -3
- mindspore/rewrite/common/__init__.py +2 -0
- mindspore/rewrite/common/event.py +1 -1
- mindspore/rewrite/common/observable.py +1 -1
- mindspore/rewrite/common/observer.py +1 -1
- mindspore/rewrite/common/rewrite_elog.py +35 -0
- mindspore/rewrite/namer.py +2 -2
- mindspore/rewrite/namespace.py +14 -4
- mindspore/rewrite/node.py +161 -13
- mindspore/rewrite/parser.py +0 -1
- mindspore/rewrite/parser_register.py +0 -1
- mindspore/rewrite/parsers/arguments_parser.py +3 -2
- mindspore/rewrite/parsers/assign_parser.py +267 -67
- mindspore/rewrite/parsers/attribute_parser.py +56 -0
- mindspore/rewrite/parsers/class_def_parser.py +191 -108
- mindspore/rewrite/parsers/constant_parser.py +101 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/for_parser.py +28 -15
- mindspore/rewrite/parsers/function_def_parser.py +21 -5
- mindspore/rewrite/parsers/if_parser.py +11 -28
- mindspore/rewrite/parsers/module_parser.py +9 -6
- mindspore/rewrite/parsers/return_parser.py +3 -2
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +448 -0
- mindspore/rewrite/sparsify/sparsify.py +109 -0
- mindspore/rewrite/sparsify/utils.py +173 -0
- mindspore/rewrite/symbol_tree.py +322 -109
- mindspore/rewrite/symbol_tree_builder.py +45 -8
- mindspore/rewrite/symbol_tree_dumper.py +0 -1
- mindspore/rewrite/topological_manager.py +1 -2
- mindspore/run_check/_check_version.py +209 -112
- mindspore/run_check/run_check.py +2 -1
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +6 -4
- mindspore/train/_utils.py +28 -5
- mindspore/train/amp.py +321 -50
- mindspore/train/callback/__init__.py +3 -1
- mindspore/train/callback/_backup_and_restore.py +120 -0
- mindspore/train/callback/_callback.py +8 -8
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_early_stop.py +13 -7
- mindspore/train/callback/_history.py +8 -8
- mindspore/train/callback/_lambda_callback.py +6 -6
- mindspore/train/callback/_landscape.py +36 -38
- mindspore/train/callback/_loss_monitor.py +12 -6
- mindspore/train/callback/_lr_scheduler_callback.py +2 -4
- mindspore/train/callback/_on_request_exit.py +212 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +13 -7
- mindspore/train/callback/_summary_collector.py +27 -19
- mindspore/train/callback/_time_monitor.py +13 -7
- mindspore/train/checkpoint_pb2.py +68 -8
- mindspore/train/data_sink.py +122 -33
- mindspore/train/dataset_helper.py +28 -87
- mindspore/train/loss_scale_manager.py +4 -7
- mindspore/{nn → train}/metrics/__init__.py +20 -20
- mindspore/{nn → train}/metrics/accuracy.py +12 -10
- mindspore/{nn → train}/metrics/auc.py +4 -4
- mindspore/{nn → train}/metrics/bleu_score.py +4 -4
- mindspore/{nn → train}/metrics/confusion_matrix.py +10 -8
- mindspore/{nn → train}/metrics/cosine_similarity.py +4 -4
- mindspore/{nn → train}/metrics/dice.py +6 -5
- mindspore/{nn → train}/metrics/error.py +7 -5
- mindspore/{nn → train}/metrics/fbeta.py +9 -7
- mindspore/{nn → train}/metrics/hausdorff_distance.py +8 -6
- mindspore/{nn → train}/metrics/loss.py +4 -3
- mindspore/{nn → train}/metrics/mean_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/metric.py +6 -5
- mindspore/{nn → train}/metrics/occlusion_sensitivity.py +4 -3
- mindspore/{nn → train}/metrics/perplexity.py +5 -4
- mindspore/{nn → train}/metrics/precision.py +5 -4
- mindspore/{nn → train}/metrics/recall.py +5 -4
- mindspore/{nn → train}/metrics/roc.py +7 -6
- mindspore/{nn → train}/metrics/root_mean_square_surface_distance.py +6 -5
- mindspore/{nn → train}/metrics/topk.py +7 -5
- mindspore/train/mind_ir_pb2.py +339 -32
- mindspore/train/model.py +113 -84
- mindspore/train/serialization.py +547 -167
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +43 -12
- mindspore/train/train_thor/convert_utils.py +7 -1
- mindspore/train/train_thor/dataset_helper.py +3 -3
- mindspore/train/train_thor/model_thor.py +0 -4
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/METADATA +4 -3
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/RECORD +901 -660
- mindspore/compression/common/constant.py +0 -124
- mindspore/compression/export/__init__.py +0 -19
- mindspore/compression/export/quant_export.py +0 -514
- mindspore/compression/quant/qat.py +0 -636
- mindspore/compression/quant/quant_utils.py +0 -462
- mindspore/compression/quant/quantizer.py +0 -68
- mindspore/libatomic-1.dll +0 -0
- mindspore/libgcc_s_seh-1.dll +0 -0
- mindspore/libgfortran-4.dll +0 -0
- mindspore/libgomp-1.dll +0 -0
- mindspore/libjpeg-62.dll +0 -0
- mindspore/libmindspore.dll +0 -0
- mindspore/libmindspore_common.dll +0 -0
- mindspore/libmindspore_core.dll +0 -0
- mindspore/libmindspore_glog.dll +0 -0
- mindspore/libnnacl.dll +0 -0
- mindspore/libopencv_core452.dll +0 -0
- mindspore/libopencv_imgcodecs452.dll +0 -0
- mindspore/libopencv_imgproc452.dll +0 -0
- mindspore/libquadmath-0.dll +0 -0
- mindspore/libsqlite3.dll +0 -0
- mindspore/libssp-0.dll +0 -0
- mindspore/libstdc++-6.dll +0 -0
- mindspore/libtinyxml2.dll +0 -0
- mindspore/libturbojpeg.dll +0 -0
- mindspore/libwinpthread-1.dll +0 -0
- mindspore/nn/layer/quant.py +0 -1868
- mindspore/nn/layer/rnn_utils.py +0 -90
- mindspore/nn/probability/dpn/__init__.py +0 -22
- mindspore/nn/probability/dpn/vae/__init__.py +0 -25
- mindspore/nn/probability/dpn/vae/cvae.py +0 -138
- mindspore/nn/probability/dpn/vae/vae.py +0 -122
- mindspore/nn/probability/infer/__init__.py +0 -22
- mindspore/nn/probability/infer/variational/elbo.py +0 -70
- mindspore/nn/probability/infer/variational/svi.py +0 -84
- mindspore/nn/probability/toolbox/__init__.py +0 -22
- mindspore/nn/probability/toolbox/anomaly_detection.py +0 -99
- mindspore/nn/probability/toolbox/uncertainty_evaluation.py +0 -363
- mindspore/nn/probability/transforms/__init__.py +0 -22
- mindspore/nn/probability/transforms/transform_bnn.py +0 -262
- mindspore/nn/probability/zhusuan/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/framework/bn.py +0 -95
- mindspore/nn/probability/zhusuan/variational/__init__.py +0 -18
- mindspore/nn/probability/zhusuan/variational/elbo.py +0 -46
- mindspore/ops/_op_impl/tbe/bias_add_grad_ds.py +0 -52
- mindspore/ops/_op_impl/tbe/scatter_nd_add_ds.py +0 -43
- mindspore/ops/bprop_mindir/AssignAdd_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Identity_bprop.mindir +0 -9
- mindspore/ops/bprop_mindir/LogicalOr_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/ReLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/UpdateState_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/stop_gradient_bprop.mindir +0 -12
- mindspore/ops/composite/array_ops.py +0 -210
- mindspore/ops/composite/clip_ops.py +0 -238
- mindspore/ops/composite/random_ops.py +0 -426
- mindspore/ops/composite/vmap_ops.py +0 -38
- mindspore/ops/operations/sponge_ops.py +0 -3531
- mindspore/ops/operations/sponge_update_ops.py +0 -2546
- mindspore/parallel/nn/__init__.py +0 -42
- mindspore/parallel/nn/loss.py +0 -22
- mindspore/parallel/nn/moe.py +0 -21
- mindspore/parallel/nn/op_parallel_config.py +0 -22
- mindspore/parallel/nn/transformer.py +0 -31
- mindspore/run_check/_check_deps_version.py +0 -84
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-1.10.0.dist-info → mindspore-2.0.0rc1.dist-info}/top_level.txt +0 -0
mindspore/rewrite/symbol_tree.py
CHANGED
|
@@ -13,22 +13,22 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""SymbolTree class define of Rewrite according to forward function of a network."""
|
|
16
|
-
from __future__ import absolute_import
|
|
17
16
|
import stat
|
|
18
17
|
from typing import Optional, Union, Tuple, Any
|
|
19
18
|
import os
|
|
20
19
|
import sys
|
|
21
20
|
import ast
|
|
22
|
-
import tempfile
|
|
23
21
|
import importlib
|
|
24
|
-
|
|
22
|
+
import types
|
|
23
|
+
import time
|
|
25
24
|
import astunparse
|
|
26
25
|
|
|
27
26
|
from mindspore.nn import Cell
|
|
28
27
|
from mindspore import log as logger
|
|
29
|
-
from .
|
|
28
|
+
from mindspore.rewrite.ast_creator_register import ast_creator_registry
|
|
29
|
+
from .node import Node, TreeNode
|
|
30
30
|
from .api.node_type import NodeType
|
|
31
|
-
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder
|
|
31
|
+
from .ast_helpers import AstModifier, AstReplacer, StrChecker, AstFinder, CheckPropertyIsUsed
|
|
32
32
|
from .api.scoped_value import ScopedValue, ValueType
|
|
33
33
|
from .symbol_tree_dumper import SymbolTreeDumper
|
|
34
34
|
from .topological_manager import TopoManager
|
|
@@ -159,7 +159,6 @@ class SymbolTree(Observer, Observable):
|
|
|
159
159
|
self._topo_mgr = TopoManager()
|
|
160
160
|
self._topo_mgr.reg_observer(self)
|
|
161
161
|
|
|
162
|
-
self._global_vars: {str, object} = {origin_network_key: origin_network}
|
|
163
162
|
self._nodes: {str, Node} = {}
|
|
164
163
|
# parameters of forward method
|
|
165
164
|
self._inputs: [Node] = []
|
|
@@ -170,6 +169,10 @@ class SymbolTree(Observer, Observable):
|
|
|
170
169
|
self._class_ast: Optional[ast.ClassDef] = None
|
|
171
170
|
self._root_ast: Optional[ast.FunctionDef] = None
|
|
172
171
|
self._init_func_ast: Optional[ast.FunctionDef] = None
|
|
172
|
+
self._deleted_field = {}
|
|
173
|
+
self._deleted_node = []
|
|
174
|
+
self._external_func_ast = []
|
|
175
|
+
self._father_class_ast = []
|
|
173
176
|
|
|
174
177
|
# head node is always point to the first node(in source code order) of SymbolTree
|
|
175
178
|
self._head = None
|
|
@@ -198,12 +201,12 @@ class SymbolTree(Observer, Observable):
|
|
|
198
201
|
for node in nodes:
|
|
199
202
|
for arg in node.get_args():
|
|
200
203
|
if consumers.get(arg):
|
|
201
|
-
consumers
|
|
204
|
+
consumers[arg].append(node)
|
|
202
205
|
else:
|
|
203
206
|
consumers[arg] = [node]
|
|
204
207
|
for _, arg in node.get_kwargs():
|
|
205
208
|
if consumers.get(arg):
|
|
206
|
-
consumers
|
|
209
|
+
consumers[arg].append(node)
|
|
207
210
|
else:
|
|
208
211
|
consumers[arg] = [node]
|
|
209
212
|
for target in node.get_targets():
|
|
@@ -262,6 +265,8 @@ class SymbolTree(Observer, Observable):
|
|
|
262
265
|
for node in stree.nodes():
|
|
263
266
|
if not isinstance(node, TreeNode):
|
|
264
267
|
continue
|
|
268
|
+
if node.symbol_tree._class_ast is None:
|
|
269
|
+
continue
|
|
265
270
|
sub_stree: SymbolTree = node.symbol_tree
|
|
266
271
|
SymbolTree._find_all_class_in_symboltree(sub_stree, seen_class, allow_class_name, replacers)
|
|
267
272
|
# all modified ast.ClassDef should export to code
|
|
@@ -280,6 +285,89 @@ class SymbolTree(Observer, Observable):
|
|
|
280
285
|
"""Add Event.TopologicalChangeEvent event when build is finished."""
|
|
281
286
|
self.add_event(Event.TopologicalChangeEvent)
|
|
282
287
|
|
|
288
|
+
def _create_call_function(self, func, targets, args, kwargs):
|
|
289
|
+
"""
|
|
290
|
+
Create a Node object and generate the execution code to insert into the source code.
|
|
291
|
+
The source code calls the 'func' function with 'args' and' kwargs' as parameters.
|
|
292
|
+
|
|
293
|
+
Args:
|
|
294
|
+
func (FunctionType) - The function to be called.
|
|
295
|
+
targets (list [str]) - indicates the output name. As the output of the node in the source code.
|
|
296
|
+
args (ParamType) - parameter name of the node. Used as a parameter to a code statement in source
|
|
297
|
+
code. The default value is None, which means there is no parameter input in the cell.
|
|
298
|
+
kwargs ({str: ParamType}) - The key type must be str, and the value type must be ParamType. The
|
|
299
|
+
input parameter name used to describe the formal parameter with a keyword. Enter the name in the source
|
|
300
|
+
code as the 'kwargs' in the statement expression. The default value is None, which means there is no
|
|
301
|
+
'kwargs' input.
|
|
302
|
+
|
|
303
|
+
Returns:
|
|
304
|
+
An instance of `Node`.
|
|
305
|
+
"""
|
|
306
|
+
if not isinstance(func, types.FunctionType):
|
|
307
|
+
raise TypeError("The 'func' parameter must be a Function, but got ", type(func))
|
|
308
|
+
|
|
309
|
+
_package = func.__globals__['__package__']
|
|
310
|
+
func_name = ".".join([_package, func.__name__]) if _package else func.__name__
|
|
311
|
+
|
|
312
|
+
ast_assign = self.create_assign_node(targets, func_name, args, kwargs)
|
|
313
|
+
scope_targets = [ScopedValue.create_naming_value(targets[0])]
|
|
314
|
+
scope_func = ScopedValue.create_naming_value(func_name, "")
|
|
315
|
+
call_args = list()
|
|
316
|
+
for arg in args:
|
|
317
|
+
if isinstance(arg, Node):
|
|
318
|
+
call_args.append(ScopedValue.create_variable_value(arg.get_targets()[0].value))
|
|
319
|
+
else:
|
|
320
|
+
call_args.append(ScopedValue.create_variable_value(arg))
|
|
321
|
+
call_kwargs = {}
|
|
322
|
+
for k, v in kwargs.items():
|
|
323
|
+
call_kwargs[k] = ScopedValue.create_variable_value(v)
|
|
324
|
+
node = self.inner_create_call_function(func_name, ast_assign, scope_func, func, scope_targets, call_args,
|
|
325
|
+
call_kwargs)
|
|
326
|
+
return node
|
|
327
|
+
|
|
328
|
+
def create_assign_node(self, targets, func_name, args, kwargs):
|
|
329
|
+
"""
|
|
330
|
+
Create a ast.Assign type node.
|
|
331
|
+
|
|
332
|
+
Args:
|
|
333
|
+
targets (list): _description_
|
|
334
|
+
func_name (_type_): _description_
|
|
335
|
+
args (_type_): _description_
|
|
336
|
+
kwargs (_type_): _description_
|
|
337
|
+
|
|
338
|
+
Returns:
|
|
339
|
+
_type_: _description_
|
|
340
|
+
"""
|
|
341
|
+
# create targets
|
|
342
|
+
ast_targets = [ast_creator_registry.get("Name")(targets)]
|
|
343
|
+
# create call
|
|
344
|
+
ast_func = ast_creator_registry.get("Attribute")(func_name)
|
|
345
|
+
ast_args = ast_creator_registry.get("Args")(args)
|
|
346
|
+
ast_kwargs = ast_creator_registry.get("KwArgs")(kwargs) if kwargs else []
|
|
347
|
+
ast_value = ast_creator_registry.get("Call")(func=ast_func, args=ast_args, keywords=ast_kwargs)
|
|
348
|
+
# create assign
|
|
349
|
+
ast_node = ast_creator_registry.get("Assign")(targets=ast_targets, value=ast_value)
|
|
350
|
+
return ast_node
|
|
351
|
+
|
|
352
|
+
def inner_create_call_function(self, node_name, ast_node, func_name, func, targets, args, kwargs):
|
|
353
|
+
'''
|
|
354
|
+
Instantiate an instance of node whose type is `CallFunction`.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
node_name (str): Name of node.
|
|
358
|
+
func_name (str): Name of function.
|
|
359
|
+
ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast.
|
|
360
|
+
targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
361
|
+
func ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class.
|
|
362
|
+
args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class.
|
|
363
|
+
kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node`
|
|
364
|
+
class.
|
|
365
|
+
'''
|
|
366
|
+
logger.info(f"func name: {func_name}; func: {func}; targets: {targets}; args: {args}; kwargs: {kwargs}")
|
|
367
|
+
node = Node(NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, func)
|
|
368
|
+
node.set_belong_symbol_tree(self)
|
|
369
|
+
return node
|
|
370
|
+
|
|
283
371
|
def get_ori_cls_name(self) -> str:
|
|
284
372
|
"""
|
|
285
373
|
Get class name of original network.
|
|
@@ -374,12 +462,6 @@ class SymbolTree(Observer, Observable):
|
|
|
374
462
|
self._init_func_ast = ast_node
|
|
375
463
|
|
|
376
464
|
def get_inputs(self):
|
|
377
|
-
"""
|
|
378
|
-
Getter of `_inputs` which represents parameters of current forward method.
|
|
379
|
-
|
|
380
|
-
Returns:
|
|
381
|
-
A list of instance of Node whose node_type is NodeType.Input as input nodes.
|
|
382
|
-
"""
|
|
383
465
|
return self._inputs
|
|
384
466
|
|
|
385
467
|
def get_head_node(self):
|
|
@@ -400,17 +482,6 @@ class SymbolTree(Observer, Observable):
|
|
|
400
482
|
"""
|
|
401
483
|
return self._origin_network
|
|
402
484
|
|
|
403
|
-
def get_global_vars(self):
|
|
404
|
-
"""Get global variables."""
|
|
405
|
-
return self._global_vars
|
|
406
|
-
|
|
407
|
-
def add_global_vars(self, key: str, value):
|
|
408
|
-
"""Add global variables."""
|
|
409
|
-
if self._global_vars.get(key) is not None:
|
|
410
|
-
logger.info(f"The key '{key}' is duplicated")
|
|
411
|
-
return
|
|
412
|
-
self._global_vars[key] = value
|
|
413
|
-
|
|
414
485
|
def get_nodes_dict(self):
|
|
415
486
|
"""Get dict of nodes"""
|
|
416
487
|
return self._nodes
|
|
@@ -530,7 +601,6 @@ class SymbolTree(Observer, Observable):
|
|
|
530
601
|
RuntimeError: If 'node_or_name' is not belong to this SymbolTree or any sub-SymbolTree of current
|
|
531
602
|
SymbolTree.
|
|
532
603
|
"""
|
|
533
|
-
|
|
534
604
|
node = self._get_real_node(node_or_name)
|
|
535
605
|
if node is None:
|
|
536
606
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
@@ -569,7 +639,12 @@ class SymbolTree(Observer, Observable):
|
|
|
569
639
|
RuntimeError: If 'position' is not in current SymbolTree.
|
|
570
640
|
RuntimeError: If corresponding ast node is not an ast.Assign when 'insert_to_ast' is True.
|
|
571
641
|
"""
|
|
572
|
-
|
|
642
|
+
if position is not None and hasattr(position.node, "container"):
|
|
643
|
+
cellcontainer = getattr(position.node, "container")
|
|
644
|
+
index = cellcontainer.node_list.index(position.node)
|
|
645
|
+
index = index if position.before_node else index + 1
|
|
646
|
+
cellcontainer.insert(index, node)
|
|
647
|
+
return node
|
|
573
648
|
# if position in current SymbolTree
|
|
574
649
|
if position is not None and position.symbol_tree is not self:
|
|
575
650
|
raise RuntimeError("Position is not in current SymbolTree:", position)
|
|
@@ -594,37 +669,7 @@ class SymbolTree(Observer, Observable):
|
|
|
594
669
|
self._node_visitor.append_node(node)
|
|
595
670
|
# update init-function-ast and construct-function-ast
|
|
596
671
|
if insert_to_ast:
|
|
597
|
-
|
|
598
|
-
node_ast = node.get_ast()
|
|
599
|
-
if not isinstance(node_ast, ast.Assign):
|
|
600
|
-
raise RuntimeError("Only support insert cell op now")
|
|
601
|
-
if isinstance(node, TreeNode):
|
|
602
|
-
global_vars_key = node.get_name() + "_args"
|
|
603
|
-
self.add_global_vars(global_vars_key, node.symbol_tree.get_global_vars())
|
|
604
|
-
args_call = AstModifier.create_call(ScopedValue.create_naming_value("get", "global_vars"),
|
|
605
|
-
[ScopedValue.create_variable_value(global_vars_key)])
|
|
606
|
-
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
|
607
|
-
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
|
608
|
-
|
|
609
|
-
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
|
610
|
-
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
|
611
|
-
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
|
612
|
-
|
|
613
|
-
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
614
|
-
None if position is None else position.node.get_ast(),
|
|
615
|
-
position.before_node)
|
|
616
|
-
sub_stree: SymbolTree = node.symbol_tree
|
|
617
|
-
from .symbol_tree_builder import SymbolTreeBuilder
|
|
618
|
-
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
|
619
|
-
else:
|
|
620
|
-
AstModifier.insert_assign_to_function(self._init_func_ast,
|
|
621
|
-
targets=[ScopedValue(ValueType.NamingValue, "self", node_name)],
|
|
622
|
-
expr=ScopedValue(ValueType.NamingValue, "global_vars", "get"),
|
|
623
|
-
args=[ScopedValue(ValueType.StringValue, "", node_name)])
|
|
624
|
-
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
625
|
-
None if position is None else position.node.get_ast(),
|
|
626
|
-
position.before_node)
|
|
627
|
-
self._global_vars[node_name] = node.get_instance()
|
|
672
|
+
self._insert_to_ast_while_insert_node(node, position)
|
|
628
673
|
return node
|
|
629
674
|
|
|
630
675
|
def append_node(self, node: Node, append_to_ast: bool = True) -> Node:
|
|
@@ -723,8 +768,9 @@ class SymbolTree(Observer, Observable):
|
|
|
723
768
|
Returns:
|
|
724
769
|
An instance of python node which has been appended to SymbolTree.
|
|
725
770
|
"""
|
|
726
|
-
logger.
|
|
771
|
+
logger.info("Ignoring unsupported node (%s) (%s).", type(ast_node).__name__, type(ast_scope).__name__)
|
|
727
772
|
node_name = self._node_name_namer.get_name(type(ast_node).__name__)
|
|
773
|
+
self._update_names_for_unique(ast_node)
|
|
728
774
|
node = Node.create_python_node(ast_node, node_name)
|
|
729
775
|
self._insert_node(Position.create(self, self._tail, False), node)
|
|
730
776
|
return node
|
|
@@ -767,6 +813,10 @@ class SymbolTree(Observer, Observable):
|
|
|
767
813
|
node = self._get_real_node(node_or_name)
|
|
768
814
|
if node is None:
|
|
769
815
|
raise RuntimeError("Node is not belong to current SymbolTree: ", node_or_name)
|
|
816
|
+
if hasattr(node, "container"):
|
|
817
|
+
cellcontainer = getattr(node, "container")
|
|
818
|
+
cellcontainer.erase(node)
|
|
819
|
+
return node
|
|
770
820
|
ret = AstModifier.erase_ast_from_function(self._root_ast, node.get_ast())
|
|
771
821
|
if not ret:
|
|
772
822
|
raise RuntimeError("node not in function ast tree.")
|
|
@@ -776,6 +826,7 @@ class SymbolTree(Observer, Observable):
|
|
|
776
826
|
value.isolate()
|
|
777
827
|
break
|
|
778
828
|
self._topo_mgr.on_erase_node(node)
|
|
829
|
+
self._deleted_node.append(node.get_name())
|
|
779
830
|
return node
|
|
780
831
|
|
|
781
832
|
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
|
@@ -800,6 +851,9 @@ class SymbolTree(Observer, Observable):
|
|
|
800
851
|
RuntimeError: If 'old_node' is not belong to current SymbolTree.
|
|
801
852
|
"""
|
|
802
853
|
|
|
854
|
+
if hasattr(old_node, "container"):
|
|
855
|
+
self._replace_container_node(old_node, new_nodes)
|
|
856
|
+
return new_nodes[0]
|
|
803
857
|
real_old_node = self._get_real_node(old_node)
|
|
804
858
|
if real_old_node is None:
|
|
805
859
|
raise RuntimeError("Old node is not belong to current SymbolTree:", old_node)
|
|
@@ -882,7 +936,6 @@ class SymbolTree(Observer, Observable):
|
|
|
882
936
|
self.set_node_arg(real_dst_node, arg_idx, new_arg)
|
|
883
937
|
|
|
884
938
|
def print_node_tabulate(self):
|
|
885
|
-
"""Print node information of graph."""
|
|
886
939
|
try:
|
|
887
940
|
from tabulate import tabulate
|
|
888
941
|
except ImportError:
|
|
@@ -898,6 +951,13 @@ class SymbolTree(Observer, Observable):
|
|
|
898
951
|
dump_st = SymbolTreeDumper(self)
|
|
899
952
|
dump_st.dump()
|
|
900
953
|
|
|
954
|
+
def update_module_ast(self):
|
|
955
|
+
for node in self._external_func_ast:
|
|
956
|
+
self._module_ast.body.append(node)
|
|
957
|
+
for node in self._father_class_ast:
|
|
958
|
+
index = self._module_ast.body.index(self._class_ast)
|
|
959
|
+
self._module_ast.body.insert(index, node)
|
|
960
|
+
|
|
901
961
|
def get_code(self) -> str:
|
|
902
962
|
"""
|
|
903
963
|
Get source code of modified network.
|
|
@@ -909,6 +969,7 @@ class SymbolTree(Observer, Observable):
|
|
|
909
969
|
if self._init_func_ast:
|
|
910
970
|
self._remove_unused_field()
|
|
911
971
|
self._remove_duplicated_import()
|
|
972
|
+
self.update_module_ast()
|
|
912
973
|
ast.fix_missing_locations(self._module_ast)
|
|
913
974
|
# Find all ast.ClassDef which can be export to code
|
|
914
975
|
# Replace duplicated ast.ClassDef reference in main-ClassDef
|
|
@@ -943,21 +1004,20 @@ class SymbolTree(Observer, Observable):
|
|
|
943
1004
|
A network object.
|
|
944
1005
|
"""
|
|
945
1006
|
cls = self._get_cls_through_file()
|
|
946
|
-
|
|
1007
|
+
new_net = cls(self._origin_network)
|
|
1008
|
+
self._merge_origin_property(new_net)
|
|
1009
|
+
return new_net
|
|
947
1010
|
|
|
948
1011
|
def set_saved_file_name(self, file_name: str):
|
|
949
|
-
"""Sets the filename used to save the network."""
|
|
950
1012
|
if file_name.endswith(".py"):
|
|
951
1013
|
self._saved_file_name = file_name
|
|
952
1014
|
else:
|
|
953
1015
|
self._saved_file_name = file_name + ".py"
|
|
954
1016
|
|
|
955
1017
|
def get_saved_file_name(self):
|
|
956
|
-
"""Gets the filename used to save the network."""
|
|
957
1018
|
return self._saved_file_name
|
|
958
1019
|
|
|
959
1020
|
def save_network_to_file(self):
|
|
960
|
-
"""Save the modified network to a file."""
|
|
961
1021
|
abs_path = os.path.abspath(self._saved_file_name)
|
|
962
1022
|
if os.path.isfile(abs_path):
|
|
963
1023
|
os.remove(abs_path)
|
|
@@ -966,6 +1026,58 @@ class SymbolTree(Observer, Observable):
|
|
|
966
1026
|
f.write(source.encode('utf-8'))
|
|
967
1027
|
f.flush()
|
|
968
1028
|
|
|
1029
|
+
def update_scope_for_unique(self, node: Union[ast.Attribute, ast.Call, ast.Subscript]):
|
|
1030
|
+
""" Update scope of ast node because of unique-ing of targets of other nodes. """
|
|
1031
|
+
if isinstance(node, ast.Call):
|
|
1032
|
+
self.update_scope_for_unique(node.func)
|
|
1033
|
+
return
|
|
1034
|
+
if not isinstance(node, (ast.Attribute, ast.Subscript)):
|
|
1035
|
+
logger.warning(f"Cannot update node {astunparse.unparse(node)} for unique, type of node should "
|
|
1036
|
+
f"be one of (ast.Attribute, ast.Subscript).")
|
|
1037
|
+
return
|
|
1038
|
+
scope = node.value
|
|
1039
|
+
if not isinstance(scope, ast.Name):
|
|
1040
|
+
self.update_scope_for_unique(scope)
|
|
1041
|
+
return
|
|
1042
|
+
scope_name = scope.id
|
|
1043
|
+
scope_name_unique = self._target_namer.get_real_arg(scope_name)
|
|
1044
|
+
scope.id = scope_name_unique
|
|
1045
|
+
|
|
1046
|
+
def _insert_to_ast_while_insert_node(self, node: Node, position: Optional[Position]):
|
|
1047
|
+
""" insert_to_ast_while_insert_node. """
|
|
1048
|
+
node.set_func(ScopedValue.create_naming_value(node.get_name(), "self"))
|
|
1049
|
+
node_ast = node.get_ast()
|
|
1050
|
+
if not isinstance(node_ast, ast.Assign):
|
|
1051
|
+
raise RuntimeError("Only support insert cell op now")
|
|
1052
|
+
if isinstance(node, TreeNode):
|
|
1053
|
+
setattr(self._origin_network, node.get_name(), node.get_instance())
|
|
1054
|
+
args_call = AstModifier.create_call(ScopedValue(ValueType.NamingValue, "", "getattr"),
|
|
1055
|
+
[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
1056
|
+
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
|
1057
|
+
value = ast.Call(func=ast.Name(node.symbol_tree.get_opt_cls_name(), ast.Store(), lineno=0,
|
|
1058
|
+
col_offset=0), args=[args_call], keywords=[], lineno=0, col_offset=0)
|
|
1059
|
+
|
|
1060
|
+
ast_target = ast.Name("self." + node.get_name(), ast.Store(), lineno=0, col_offset=0)
|
|
1061
|
+
assign = ast.Assign(targets=[ast_target], value=value, lineno=0, col_offset=0)
|
|
1062
|
+
AstModifier.insert_assign_ast_to_function(self._init_func_ast, assign)
|
|
1063
|
+
|
|
1064
|
+
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
1065
|
+
None if position is None else position.node.get_ast(),
|
|
1066
|
+
position.before_node)
|
|
1067
|
+
sub_stree: SymbolTree = node.symbol_tree
|
|
1068
|
+
from .symbol_tree_builder import SymbolTreeBuilder
|
|
1069
|
+
SymbolTreeBuilder.merge_module_of_subtree(self, sub_stree)
|
|
1070
|
+
else:
|
|
1071
|
+
AstModifier.insert_assign_to_function(self._init_func_ast,
|
|
1072
|
+
targets=[ScopedValue(ValueType.NamingValue, "self", node.get_name())],
|
|
1073
|
+
expr=ScopedValue(ValueType.NamingValue, "", "getattr"),
|
|
1074
|
+
args=[ScopedValue(ValueType.NamingValue, "", "obj"),
|
|
1075
|
+
ScopedValue(ValueType.StringValue, "", node.get_name())])
|
|
1076
|
+
AstModifier.insert_assign_ast_to_function(self._root_ast, node_ast,
|
|
1077
|
+
None if position is None else position.node.get_ast(),
|
|
1078
|
+
position.before_node)
|
|
1079
|
+
setattr(self._origin_network, node.get_name(), node.get_instance())
|
|
1080
|
+
|
|
969
1081
|
def _remove_unused_import(self):
|
|
970
1082
|
"""remove unused import in self._module_ast"""
|
|
971
1083
|
str_checker = StrChecker(self._module_ast)
|
|
@@ -987,38 +1099,43 @@ class SymbolTree(Observer, Observable):
|
|
|
987
1099
|
else:
|
|
988
1100
|
body.names.remove(alias)
|
|
989
1101
|
|
|
1102
|
+
def _replace_container_node(self, old_node, new_nodes):
|
|
1103
|
+
cellcontainer = getattr(old_node, "container")
|
|
1104
|
+
index = cellcontainer.node_list.index(old_node)
|
|
1105
|
+
for n in reversed(new_nodes):
|
|
1106
|
+
cellcontainer.insert(index, n)
|
|
1107
|
+
index = cellcontainer.node_list.index(old_node)
|
|
1108
|
+
cellcontainer.erase(old_node)
|
|
1109
|
+
|
|
990
1110
|
def _filter_out_to_delete_field(self, to_delete_field):
|
|
991
1111
|
"""filter out used field from `to_delete_field`"""
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
if to_delete_field.get("_handler"):
|
|
995
|
-
to_delete_field.pop("_handler")
|
|
996
|
-
# filter field used in node of construct
|
|
997
|
-
for node in self._nodes.values():
|
|
998
|
-
if node.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree):
|
|
999
|
-
func: ScopedValue = node.get_func()
|
|
1000
|
-
if func.scope == "self" and to_delete_field.get(func.value):
|
|
1001
|
-
to_delete_field.pop(func.value)
|
|
1002
|
-
if node.get_node_type() == NodeType.CallMethod and node.get_func() == PASS_THROUGH_METHOD:
|
|
1003
|
-
var_name = node.get_args()[0].value
|
|
1004
|
-
if to_delete_field.get(var_name):
|
|
1005
|
-
to_delete_field.pop(var_name)
|
|
1006
|
-
# filter field used in test-of-if
|
|
1007
|
-
for body in self._root_ast.body:
|
|
1008
|
-
if not isinstance(body, ast.If):
|
|
1112
|
+
for func_def in self._class_ast.body:
|
|
1113
|
+
if not isinstance(func_def, ast.FunctionDef):
|
|
1009
1114
|
continue
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1115
|
+
if func_def.name != "__init__":
|
|
1116
|
+
to_delete_to_delete_keys = []
|
|
1117
|
+
property_checker = CheckPropertyIsUsed(func_def)
|
|
1118
|
+
for key, _ in self._deleted_field.items():
|
|
1119
|
+
if property_checker.check("self", key):
|
|
1120
|
+
to_delete_to_delete_keys.append(key)
|
|
1121
|
+
property_checker = CheckPropertyIsUsed(func_def)
|
|
1122
|
+
for key in to_delete_to_delete_keys:
|
|
1123
|
+
self._deleted_field.pop(key)
|
|
1124
|
+
else:
|
|
1125
|
+
for body in func_def.body:
|
|
1126
|
+
if not isinstance(body, ast.If):
|
|
1127
|
+
continue
|
|
1128
|
+
test = body.test
|
|
1129
|
+
field_finder = FieldFinder(test)
|
|
1130
|
+
to_delete_to_delete_keys = []
|
|
1131
|
+
for key, _ in self._deleted_field.items():
|
|
1132
|
+
if field_finder.check(key):
|
|
1133
|
+
to_delete_to_delete_keys.append(key)
|
|
1134
|
+
for key in to_delete_to_delete_keys:
|
|
1135
|
+
self._deleted_field.pop(key)
|
|
1018
1136
|
|
|
1019
1137
|
def _remove_unused_field(self):
|
|
1020
1138
|
"""remove unused field in __init__ function"""
|
|
1021
|
-
to_delete_field = {}
|
|
1022
1139
|
multi_targets = []
|
|
1023
1140
|
for index, body in enumerate(self._init_func_ast.body):
|
|
1024
1141
|
if not isinstance(body, ast.Assign):
|
|
@@ -1027,12 +1144,12 @@ class SymbolTree(Observer, Observable):
|
|
|
1027
1144
|
for target in targets:
|
|
1028
1145
|
if isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name) \
|
|
1029
1146
|
and target.value.id == "self":
|
|
1030
|
-
|
|
1147
|
+
self._deleted_field[target.attr] = index
|
|
1031
1148
|
if len(targets) > 1:
|
|
1032
1149
|
multi_targets.append(index)
|
|
1033
|
-
self._filter_out_to_delete_field(
|
|
1150
|
+
self._filter_out_to_delete_field(self._deleted_field)
|
|
1034
1151
|
for i in range(len(self._init_func_ast.body) - 1, -1, -1):
|
|
1035
|
-
if i in
|
|
1152
|
+
if i in self._deleted_field.values():
|
|
1036
1153
|
if i in multi_targets:
|
|
1037
1154
|
raise RuntimeError("Can not erase field ast node in __init__ function because of multi-targets")
|
|
1038
1155
|
AstModifier.erase_ast_from_function(self._init_func_ast, self._init_func_ast.body[i])
|
|
@@ -1050,12 +1167,9 @@ class SymbolTree(Observer, Observable):
|
|
|
1050
1167
|
self._module_ast.body.remove(body)
|
|
1051
1168
|
|
|
1052
1169
|
def _get_real_node(self, node_or_name: Union[Node, str]) -> Optional[Node]:
|
|
1053
|
-
if isinstance(node_or_name, Node):
|
|
1054
|
-
result = self.get_node(node_or_name.get_name())
|
|
1055
|
-
return result if result is node_or_name else None
|
|
1056
1170
|
if isinstance(node_or_name, str):
|
|
1057
1171
|
return self.get_node(node_or_name)
|
|
1058
|
-
return
|
|
1172
|
+
return node_or_name
|
|
1059
1173
|
|
|
1060
1174
|
def _insert_tree(self, position: Position, root: Node, insert_to_ast: bool = True) -> Node:
|
|
1061
1175
|
"""
|
|
@@ -1204,7 +1318,7 @@ class SymbolTree(Observer, Observable):
|
|
|
1204
1318
|
raise TypeError("value should be ScopedValue, got: ", type(value))
|
|
1205
1319
|
if value.type == ValueType.CustomObjValue:
|
|
1206
1320
|
field = self._node_name_namer.get_name(f"var_{type(value.value).__name__}")
|
|
1207
|
-
self.
|
|
1321
|
+
setattr(self._origin_network, field, value.value)
|
|
1208
1322
|
init_targets = [ScopedValue.create_naming_value(field, "self")]
|
|
1209
1323
|
AstModifier.append_global_vars_expr_to_init(self._init_func_ast, init_targets, field)
|
|
1210
1324
|
result[arg] = init_targets[0]
|
|
@@ -1222,19 +1336,34 @@ class SymbolTree(Observer, Observable):
|
|
|
1222
1336
|
Returns:
|
|
1223
1337
|
A class handle.
|
|
1224
1338
|
"""
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1339
|
+
self._update_container()
|
|
1340
|
+
file_path = os.getcwd()
|
|
1341
|
+
file_path = os.path.join(file_path, "rewritten_network")
|
|
1342
|
+
if not os.path.exists(file_path):
|
|
1343
|
+
os.mkdir(file_path)
|
|
1344
|
+
file_name = "{0}_{1}.py".format(self._opt_cls_name, id(self))
|
|
1345
|
+
network_file = os.path.join(file_path, file_name)
|
|
1346
|
+
with os.fdopen(os.open(network_file, os.O_WRONLY | os.O_CREAT, stat.S_IRWXU), 'wb') as f:
|
|
1347
|
+
source = self.get_code()
|
|
1348
|
+
f.write(source.encode('utf-8'))
|
|
1349
|
+
f.flush()
|
|
1350
|
+
os.fsync(f)
|
|
1351
|
+
tmp_module_path, tmp_module_file = os.path.split(network_file)
|
|
1235
1352
|
tmp_module_name = tmp_module_file[:-3]
|
|
1236
1353
|
sys.path.append(tmp_module_path)
|
|
1237
|
-
tmp_module =
|
|
1354
|
+
tmp_module = None
|
|
1355
|
+
|
|
1356
|
+
i = 0
|
|
1357
|
+
while not tmp_module:
|
|
1358
|
+
try:
|
|
1359
|
+
tmp_module = importlib.import_module(tmp_module_name)
|
|
1360
|
+
except ModuleNotFoundError:
|
|
1361
|
+
if i > 10:
|
|
1362
|
+
break
|
|
1363
|
+
time.sleep(0.1)
|
|
1364
|
+
i += 1
|
|
1365
|
+
if not tmp_module:
|
|
1366
|
+
logger.error(f"load module {tmp_module_name} failed.")
|
|
1238
1367
|
network_cls = getattr(tmp_module, self._opt_cls_name)
|
|
1239
1368
|
if network_cls is None:
|
|
1240
1369
|
raise RuntimeError("Can not find network class:", self._opt_cls_name)
|
|
@@ -1243,3 +1372,87 @@ class SymbolTree(Observer, Observable):
|
|
|
1243
1372
|
def _on_change(self, event: Event):
|
|
1244
1373
|
self._modified = True
|
|
1245
1374
|
self.changed(event)
|
|
1375
|
+
|
|
1376
|
+
def _update_container(self):
|
|
1377
|
+
"""Update instance of node in container."""
|
|
1378
|
+
for node in self.nodes():
|
|
1379
|
+
index = 0
|
|
1380
|
+
if node.get_node_type() == NodeType.CellContainer:
|
|
1381
|
+
for n in node.node_list:
|
|
1382
|
+
if not n.valid:
|
|
1383
|
+
continue
|
|
1384
|
+
if n.get_node_type() == NodeType.Tree:
|
|
1385
|
+
obj = n.symbol_tree.get_network()
|
|
1386
|
+
node.get_instance()[index] = obj
|
|
1387
|
+
else:
|
|
1388
|
+
node.get_instance()[index] = n.get_instance()
|
|
1389
|
+
index += 1
|
|
1390
|
+
|
|
1391
|
+
def _cal_difference_set(self, input, other):
|
|
1392
|
+
"""Calculate different set of two sets."""
|
|
1393
|
+
set1 = set(input)
|
|
1394
|
+
set2 = set(other)
|
|
1395
|
+
return set1 - set2
|
|
1396
|
+
|
|
1397
|
+
def _merge_origin_property(self, new_net):
|
|
1398
|
+
"""Merge property of two network."""
|
|
1399
|
+
tmp = self._cal_difference_set(dir(self._origin_network), dir(new_net))
|
|
1400
|
+
new_attr_names = self._cal_difference_set(tmp, self._deleted_field.keys())
|
|
1401
|
+
for name in new_attr_names:
|
|
1402
|
+
setattr(new_net, name, getattr(self._origin_network, name))
|
|
1403
|
+
# merger cells
|
|
1404
|
+
cells = self._cal_difference_set(self._origin_network.name_cells().keys(), new_net.name_cells().keys())
|
|
1405
|
+
cells = self._cal_difference_set(cells, self._deleted_node)
|
|
1406
|
+
for c in cells:
|
|
1407
|
+
new_net.insert_child_to_cell(c, self._origin_network.name_cells()[c])
|
|
1408
|
+
# merge primitives
|
|
1409
|
+
primitives = self._cal_difference_set(self._origin_network._primitives.keys(), new_net._primitives.keys())
|
|
1410
|
+
for p in primitives:
|
|
1411
|
+
new_net._primitives[p] = self._origin_network._primitives[p]
|
|
1412
|
+
|
|
1413
|
+
def _update_names_for_unique(self, node: ast.AST):
|
|
1414
|
+
""" Update names of ast nodes for unique. """
|
|
1415
|
+
if isinstance(node, (ast.For, ast.If, ast.While)):
|
|
1416
|
+
self._update_names_for_unique_branchs(node)
|
|
1417
|
+
elif isinstance(node, ast.Assign):
|
|
1418
|
+
self._update_names_for_unique(node.value)
|
|
1419
|
+
for target in node.targets:
|
|
1420
|
+
self._update_names_for_unique(target)
|
|
1421
|
+
elif isinstance(node, ast.Call):
|
|
1422
|
+
if isinstance(node.func, ast.Attribute):
|
|
1423
|
+
self._update_names_for_unique(node.func.value)
|
|
1424
|
+
for arg in node.args:
|
|
1425
|
+
self._update_names_for_unique(arg)
|
|
1426
|
+
for keyword in node.keywords:
|
|
1427
|
+
self._update_names_for_unique(keyword)
|
|
1428
|
+
elif isinstance(node, ast.UnaryOp):
|
|
1429
|
+
self._update_names_for_unique(node.operand)
|
|
1430
|
+
elif isinstance(node, ast.BinOp):
|
|
1431
|
+
self._update_names_for_unique(node.left)
|
|
1432
|
+
self._update_names_for_unique(node.right)
|
|
1433
|
+
elif isinstance(node, (ast.Attribute, ast.Subscript, ast.Return)):
|
|
1434
|
+
self._update_names_for_unique(node.value)
|
|
1435
|
+
elif isinstance(node, (ast.List, ast.Tuple)):
|
|
1436
|
+
for elt in node.elts:
|
|
1437
|
+
self._update_names_for_unique(elt)
|
|
1438
|
+
elif isinstance(node, ast.Compare):
|
|
1439
|
+
for comparator in node.comparators:
|
|
1440
|
+
self._update_names_for_unique(comparator)
|
|
1441
|
+
elif isinstance(node, ast.Name):
|
|
1442
|
+
node.id = self._target_namer.get_real_arg(node.id)
|
|
1443
|
+
|
|
1444
|
+
def _update_names_for_unique_branchs(self, node: Union[ast.For, ast.If, ast.While]):
|
|
1445
|
+
""" Update names of ast nodes for unique with ast.For, ast.If or ast.While """
|
|
1446
|
+
if isinstance(node, ast.For):
|
|
1447
|
+
self._update_names_for_unique(node.target)
|
|
1448
|
+
self._update_names_for_unique(node.iter)
|
|
1449
|
+
for body in node.body:
|
|
1450
|
+
self._update_names_for_unique(body)
|
|
1451
|
+
for body in node.orelse:
|
|
1452
|
+
self._update_names_for_unique(body)
|
|
1453
|
+
elif isinstance(node, (ast.If, ast.While)):
|
|
1454
|
+
self._update_names_for_unique(node.test)
|
|
1455
|
+
for body in node.body:
|
|
1456
|
+
self._update_names_for_unique(body)
|
|
1457
|
+
for body in node.orelse:
|
|
1458
|
+
self._update_names_for_unique(body)
|