mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -14,14 +14,20 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Ast optimizer for flatten recursive call."""
|
|
16
16
|
|
|
17
|
-
|
|
17
|
+
import sys
|
|
18
|
+
from typing import Any, Tuple, List
|
|
19
|
+
import keyword
|
|
18
20
|
import ast
|
|
19
|
-
from ast import FunctionDef
|
|
20
|
-
import astunparse
|
|
21
21
|
|
|
22
22
|
from mindspore import log as logger
|
|
23
23
|
from ..common import error_str
|
|
24
24
|
|
|
25
|
+
if sys.version_info >= (3, 9):
|
|
26
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
27
|
+
else:
|
|
28
|
+
import astunparse
|
|
29
|
+
|
|
30
|
+
FLATTEN_BLACK_LIST = ["set_vertex_attr",]
|
|
25
31
|
|
|
26
32
|
class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
27
33
|
"""Ast optimizer for flatten recursive call."""
|
|
@@ -40,17 +46,35 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
40
46
|
ast.BoolOp: ["values"],
|
|
41
47
|
ast.UnaryOp: ["operand"],
|
|
42
48
|
ast.Compare: ["left", "comparators"],
|
|
49
|
+
ast.If: ["test"]
|
|
43
50
|
}
|
|
51
|
+
self._transform_functions = []
|
|
52
|
+
self._transform_if = False
|
|
53
|
+
self._symbol_tree = None # Used to get unique name
|
|
44
54
|
|
|
45
55
|
@staticmethod
|
|
46
|
-
def
|
|
56
|
+
def _check_flatten_black_list(node: ast.AST):
|
|
57
|
+
"""Check whether node in flatten black list"""
|
|
58
|
+
func_name = ""
|
|
59
|
+
# Get func name of node
|
|
60
|
+
if isinstance(node, ast.Call):
|
|
61
|
+
if isinstance(node.func, ast.Name):
|
|
62
|
+
func_name = node.func.id
|
|
63
|
+
elif isinstance(node.func, ast.Attribute):
|
|
64
|
+
func_name = node.func.attr
|
|
65
|
+
# Check func name of node
|
|
66
|
+
if func_name and func_name in FLATTEN_BLACK_LIST:
|
|
67
|
+
return True
|
|
68
|
+
return False
|
|
69
|
+
|
|
70
|
+
def _generate_target_name(self, node: ast.AST, target_names):
|
|
47
71
|
"""Generate unique target name."""
|
|
48
72
|
if isinstance(node, ast.Call):
|
|
49
73
|
func = node.func
|
|
50
74
|
if isinstance(func, ast.Name):
|
|
51
|
-
target_name = func.id
|
|
75
|
+
target_name = func.id + "_var"
|
|
52
76
|
elif isinstance(func, ast.Attribute):
|
|
53
|
-
target_name = func.attr
|
|
77
|
+
target_name = func.attr + "_var"
|
|
54
78
|
else:
|
|
55
79
|
logger.info("unhandled type of func of ast.Call while generating new target name: %s ", type(func))
|
|
56
80
|
target_name = "function"
|
|
@@ -67,30 +91,33 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
67
91
|
else:
|
|
68
92
|
logger.info("unhandled type of node while generating new target name: %s ", type(node))
|
|
69
93
|
target_name = type(node).__name__.lower() + "_var"
|
|
94
|
+
# avoid python keyword
|
|
95
|
+
if keyword.iskeyword(target_name):
|
|
96
|
+
target_name = target_name + "_var"
|
|
70
97
|
suffix = 0
|
|
71
98
|
result = target_name
|
|
72
99
|
while result in target_names:
|
|
73
100
|
suffix += 1
|
|
74
101
|
result = f"{target_name}_{suffix}"
|
|
102
|
+
if self._symbol_tree:
|
|
103
|
+
result = self._symbol_tree.unique_name(result)
|
|
75
104
|
target_names.append(result)
|
|
76
105
|
return result
|
|
77
106
|
|
|
78
|
-
|
|
79
|
-
def _create_new_assign_node(node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
|
107
|
+
def _create_new_assign_node(self, node: ast.AST, target_names) -> Tuple[str, ast.AST]:
|
|
80
108
|
"""Create new assign node to be inserted into ast.FunctionDef."""
|
|
81
109
|
if isinstance(node, (ast.Name, ast.Constant, ast.Num, ast.Str, ast.NameConstant, ast.Bytes, ast.Ellipsis)):
|
|
82
110
|
return "", node
|
|
83
|
-
new_target_name =
|
|
111
|
+
new_target_name = self._generate_target_name(node, target_names)
|
|
84
112
|
return new_target_name, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)
|
|
85
113
|
|
|
86
|
-
|
|
87
|
-
def _flatten_list(node_list, target_names):
|
|
114
|
+
def _flatten_list(self, node_list, target_names):
|
|
88
115
|
"""Flatten a list of node."""
|
|
89
116
|
results = list()
|
|
90
117
|
new_list = list()
|
|
91
118
|
for node in node_list:
|
|
92
119
|
if isinstance(node, ast.Call):
|
|
93
|
-
new_target, new_node =
|
|
120
|
+
new_target, new_node = self._create_new_assign_node(node, target_names)
|
|
94
121
|
results.append(new_node)
|
|
95
122
|
new_list.append(ast.Name(id=new_target, ctx=ast.Load()))
|
|
96
123
|
else:
|
|
@@ -99,6 +126,8 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
99
126
|
|
|
100
127
|
def _flatten_statement(self, node: ast.AST, target_names) -> [ast.AST]:
|
|
101
128
|
"""Flatten recursive statement according to different node type."""
|
|
129
|
+
if FlattenRecursiveStmt._check_flatten_black_list(node):
|
|
130
|
+
return []
|
|
102
131
|
flatten_config = self._flatten_table.get(type(node))
|
|
103
132
|
if flatten_config is None:
|
|
104
133
|
return []
|
|
@@ -108,21 +137,25 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
108
137
|
if isinstance(todos, list):
|
|
109
138
|
new_list = []
|
|
110
139
|
for todo in todos:
|
|
111
|
-
|
|
140
|
+
# Starred expression(e.g. *args) cannot be flatten.
|
|
141
|
+
if isinstance(todo, ast.Starred):
|
|
142
|
+
new_list.append(todo)
|
|
143
|
+
continue
|
|
144
|
+
new_target_name, new_node = self._create_new_assign_node(todo, target_names)
|
|
112
145
|
if id(new_node) == id(todo):
|
|
113
146
|
new_list.append(todo)
|
|
114
147
|
else:
|
|
115
148
|
new_list.append(ast.Name(id=new_target_name, ctx=ast.Load()))
|
|
116
149
|
results.append(new_node)
|
|
117
150
|
if isinstance(todo, (ast.Tuple, tuple)):
|
|
118
|
-
_res, _new_list =
|
|
151
|
+
_res, _new_list = self._flatten_list(new_node.value.elts, [new_target_name])
|
|
119
152
|
new_node.value.elts = _new_list
|
|
120
153
|
results.extend(_res)
|
|
121
154
|
setattr(node, todo_name, new_list)
|
|
122
155
|
elif isinstance(todos, dict):
|
|
123
156
|
new_dict = []
|
|
124
157
|
for key, value in todos:
|
|
125
|
-
new_target_name, new_node =
|
|
158
|
+
new_target_name, new_node = self._create_new_assign_node(value, target_names)
|
|
126
159
|
if id(new_node) == id(value):
|
|
127
160
|
new_dict[key] = value
|
|
128
161
|
else:
|
|
@@ -130,16 +163,15 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
130
163
|
results.append(new_node)
|
|
131
164
|
setattr(node, todo_name, new_dict)
|
|
132
165
|
else:
|
|
133
|
-
new_target_name, new_node =
|
|
166
|
+
new_target_name, new_node = self._create_new_assign_node(todos, target_names)
|
|
134
167
|
if id(new_node) != id(todos):
|
|
135
168
|
setattr(node, todo_name, ast.Name(id=new_target_name, ctx=ast.Load()))
|
|
136
169
|
results.append(new_node)
|
|
137
170
|
return results
|
|
138
171
|
|
|
139
|
-
def
|
|
140
|
-
"""
|
|
141
|
-
for
|
|
142
|
-
child = node.body[function_index]
|
|
172
|
+
def _save_target_names(self, target_names, ast_body: List[ast.AST]):
|
|
173
|
+
"""Saving target names in ast_body before getting unique names."""
|
|
174
|
+
for child in ast_body:
|
|
143
175
|
if isinstance(child, (ast.Assign, ast.Expr)):
|
|
144
176
|
child_value = child.value
|
|
145
177
|
else:
|
|
@@ -151,7 +183,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
151
183
|
continue
|
|
152
184
|
targets = child.targets
|
|
153
185
|
for target in targets:
|
|
154
|
-
if not isinstance(target, (ast.Name, ast.Tuple)):
|
|
186
|
+
if not isinstance(target, (ast.Name, ast.Tuple, ast.List)):
|
|
155
187
|
raise RuntimeError(
|
|
156
188
|
error_str(f"currently only support ast.Name targets, but got ast type "
|
|
157
189
|
f"'{type(target).__name__}'", child_node=target, father_node=child))
|
|
@@ -159,7 +191,7 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
159
191
|
target_name = target.id
|
|
160
192
|
if target_name not in target_names:
|
|
161
193
|
target_names.append(target_name)
|
|
162
|
-
elif isinstance(target, ast.Tuple):
|
|
194
|
+
elif isinstance(target, (ast.Tuple, ast.List)):
|
|
163
195
|
for elt in target.elts:
|
|
164
196
|
if not isinstance(elt, ast.Name):
|
|
165
197
|
raise RuntimeError(
|
|
@@ -170,47 +202,66 @@ class FlattenRecursiveStmt(ast.NodeTransformer):
|
|
|
170
202
|
if target_name not in target_names:
|
|
171
203
|
target_names.append(target_name)
|
|
172
204
|
|
|
173
|
-
def
|
|
174
|
-
"""Traverse
|
|
175
|
-
if node.name != "construct":
|
|
176
|
-
return node
|
|
177
|
-
|
|
205
|
+
def _visit_ast_bodies(self, ast_body: List[ast.AST]):
|
|
206
|
+
"""Traverse nodes in ast_body and flatten nodes recursive."""
|
|
178
207
|
target_names = []
|
|
179
|
-
self.
|
|
180
|
-
index = len(
|
|
208
|
+
self._save_target_names(target_names, ast_body)
|
|
209
|
+
index = len(ast_body) - 1
|
|
181
210
|
while index >= 0:
|
|
182
|
-
child =
|
|
211
|
+
child = ast_body[index]
|
|
183
212
|
if isinstance(child, ast.Assign):
|
|
184
213
|
stmt = child.value
|
|
185
214
|
elif isinstance(child, ast.If):
|
|
186
215
|
if isinstance(child.body[0], ast.Return) and not isinstance(child.test, ast.UnaryOp):
|
|
187
|
-
if isinstance(child.body[0].value, ast.
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
child.body =
|
|
194
|
-
|
|
195
|
-
else:
|
|
196
|
-
stmt = child
|
|
197
|
-
else:
|
|
198
|
-
stmt = child
|
|
216
|
+
if not isinstance(child.body[0].value, (ast.Name, ast.Constant)):
|
|
217
|
+
return_val_ast = child.body[0].value
|
|
218
|
+
return_name = self._generate_target_name(return_val_ast, target_names)
|
|
219
|
+
new_assign_code = f"{return_name} = {astunparse.unparse(return_val_ast)}"
|
|
220
|
+
new_assign_ast = ast.parse(new_assign_code).body[0]
|
|
221
|
+
new_return_ast = ast.parse(f"return {return_name}").body[0]
|
|
222
|
+
child.body = [new_assign_ast, new_return_ast]
|
|
223
|
+
stmt = child
|
|
199
224
|
elif isinstance(child, ast.Expr):
|
|
200
225
|
stmt = child.value
|
|
201
226
|
else:
|
|
202
227
|
stmt = child
|
|
203
228
|
results = self._flatten_statement(stmt, target_names)
|
|
204
229
|
if results:
|
|
205
|
-
results
|
|
206
|
-
|
|
207
|
-
node.body.insert(index, result)
|
|
230
|
+
for result in reversed(results):
|
|
231
|
+
ast_body.insert(index, result)
|
|
208
232
|
index += 1
|
|
209
233
|
index -= 1
|
|
234
|
+
|
|
235
|
+
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # pylint: disable=invalid-name
|
|
236
|
+
"""Traverse nodes in _transform_functions and flatten recursive nodes."""
|
|
237
|
+
if node.name not in self._transform_functions:
|
|
238
|
+
return node
|
|
239
|
+
self._visit_ast_bodies(node.body)
|
|
210
240
|
return node
|
|
211
241
|
|
|
212
|
-
def
|
|
242
|
+
def visit_If(self, node: ast.If) -> Any: # pylint: disable=invalid-name
|
|
243
|
+
"""Traverse nodes in if node and flatten recursive nodes."""
|
|
244
|
+
if not self._transform_if:
|
|
245
|
+
return node
|
|
246
|
+
self._visit_ast_bodies(node.body)
|
|
247
|
+
if node.orelse:
|
|
248
|
+
self._visit_ast_bodies(node.orelse)
|
|
249
|
+
return node
|
|
250
|
+
|
|
251
|
+
def transform(self, ast_root, transform_functions=None, stree=None):
|
|
213
252
|
"""Interface of FlattenRecursiveStmt."""
|
|
253
|
+
self._transform_functions = transform_functions if transform_functions else ["construct"]
|
|
254
|
+
self._transform_if = False
|
|
255
|
+
self._symbol_tree = stree
|
|
214
256
|
ast_root = self.visit(ast_root)
|
|
215
257
|
ast_root = ast.fix_missing_locations(ast_root)
|
|
216
258
|
return ast_root
|
|
259
|
+
|
|
260
|
+
def transform_if(self, ast_if, stree=None):
|
|
261
|
+
"""Interface of FlattenRecursiveStmt."""
|
|
262
|
+
self._transform_functions = []
|
|
263
|
+
self._transform_if = True
|
|
264
|
+
self._symbol_tree = stree
|
|
265
|
+
ast_if = self.visit(ast_if)
|
|
266
|
+
ast_if = ast.fix_missing_locations(ast_if)
|
|
267
|
+
return ast_if
|
|
@@ -14,8 +14,12 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Error Log for Rewrite."""
|
|
16
16
|
|
|
17
|
+
import sys
|
|
17
18
|
import ast
|
|
18
|
-
|
|
19
|
+
if sys.version_info >= (3, 9):
|
|
20
|
+
import ast as astunparse # pylint: disable=reimported, ungrouped-imports
|
|
21
|
+
else:
|
|
22
|
+
import astunparse
|
|
19
23
|
|
|
20
24
|
|
|
21
25
|
def error_str(reason: str, child_node: ast.expr = None, father_node: ast.expr = None) -> str:
|
mindspore/rewrite/namer.py
CHANGED
|
@@ -14,9 +14,9 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Unique name producer for target, name of node, class name, etc."""
|
|
16
16
|
|
|
17
|
-
from typing import Union
|
|
17
|
+
from typing import Union, Tuple
|
|
18
18
|
|
|
19
|
-
from .node import Node
|
|
19
|
+
from .node.node import Node
|
|
20
20
|
from .api.node_type import NodeType
|
|
21
21
|
|
|
22
22
|
|
|
@@ -33,7 +33,7 @@ class Namer:
|
|
|
33
33
|
self._names: {str: int} = {}
|
|
34
34
|
|
|
35
35
|
@staticmethod
|
|
36
|
-
def _real_name(name: str) -> str:
|
|
36
|
+
def _real_name(name: str) -> Tuple[str, int]:
|
|
37
37
|
"""
|
|
38
38
|
Find real name. For example, "name1" is the real name of "name1_10", "name1" is the real name of "name1_10_3".
|
|
39
39
|
If not find real name before find unique name, unique name may be not unique. For example:
|
|
@@ -47,21 +47,21 @@ class Namer:
|
|
|
47
47
|
name (str): Origin name which may have digit prefix.
|
|
48
48
|
|
|
49
49
|
Returns:
|
|
50
|
-
A string represents real-name.
|
|
50
|
+
A string represents real-name and a int represents suffix.
|
|
51
51
|
"""
|
|
52
52
|
if name == '_':
|
|
53
|
-
return name
|
|
53
|
+
return name, None
|
|
54
54
|
pos = name.rfind("_")
|
|
55
|
-
if pos == -1:
|
|
56
|
-
return name
|
|
55
|
+
if pos == -1 or pos == len(name) - 1:
|
|
56
|
+
return name, None
|
|
57
57
|
digit = True
|
|
58
58
|
for i in range(pos + 1, len(name)):
|
|
59
59
|
if not name[i].isdigit():
|
|
60
60
|
digit = False
|
|
61
61
|
break
|
|
62
62
|
if digit:
|
|
63
|
-
return
|
|
64
|
-
return name
|
|
63
|
+
return name[:pos], int(name[pos + 1:])
|
|
64
|
+
return name, None
|
|
65
65
|
|
|
66
66
|
def get_name(self, origin_name: str) -> str:
|
|
67
67
|
"""
|
|
@@ -75,13 +75,31 @@ class Namer:
|
|
|
75
75
|
"""
|
|
76
76
|
if origin_name == '_':
|
|
77
77
|
return origin_name
|
|
78
|
-
|
|
79
|
-
|
|
78
|
+
real_name, suffix_idx = Namer._real_name(origin_name)
|
|
79
|
+
name = origin_name
|
|
80
|
+
number = self._names.get(name)
|
|
80
81
|
if number is None:
|
|
81
|
-
self._names[
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
82
|
+
self._names[name] = 1
|
|
83
|
+
if not suffix_idx:
|
|
84
|
+
# When _names is {x:2} and origin_name is y,
|
|
85
|
+
# origin_name is not in _names and can be returned.
|
|
86
|
+
return name
|
|
87
|
+
if suffix_idx and not self._names.get(real_name, -1) >= suffix_idx:
|
|
88
|
+
# When _names is {x:2} and origin_name is x_3,
|
|
89
|
+
# return x_3 and update _names to {x:2, x_3:1}
|
|
90
|
+
return name
|
|
91
|
+
# When _names is {x:2} and origin_name is x_1,
|
|
92
|
+
# set new_name to x_1_1 by set number to 1, and continue to update name.
|
|
93
|
+
number = 1
|
|
94
|
+
while True:
|
|
95
|
+
new_name = f"{name}_{number}"
|
|
96
|
+
number += 1
|
|
97
|
+
self._names[name] = number
|
|
98
|
+
# When _names is {x:2, x_3:1}, origin_name is x and number is update to 3,
|
|
99
|
+
# new_name x_3 is conflict with key x_3, so this new_name need to be skipped.
|
|
100
|
+
if new_name in self._names.keys():
|
|
101
|
+
continue
|
|
102
|
+
return new_name
|
|
85
103
|
|
|
86
104
|
def add_name(self, name: str):
|
|
87
105
|
"""
|
|
@@ -93,44 +111,25 @@ class Namer:
|
|
|
93
111
|
Raises:
|
|
94
112
|
RuntimeError: If name is not unique in current namer.
|
|
95
113
|
"""
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
if number is not None:
|
|
99
|
-
raise RuntimeError("name duplicated: ", name)
|
|
100
|
-
self._names[name] = 1
|
|
114
|
+
if self._names.get(name) is None:
|
|
115
|
+
self._names[name] = 1
|
|
101
116
|
|
|
102
117
|
|
|
103
118
|
class TargetNamer(Namer):
|
|
104
119
|
"""
|
|
105
120
|
Used for unique-ing targets of node.
|
|
106
121
|
"""
|
|
107
|
-
def
|
|
108
|
-
super().__init__()
|
|
109
|
-
self._origin_name_map = {}
|
|
110
|
-
|
|
111
|
-
def get_name(self, origin_name: str) -> str:
|
|
112
|
-
ret = super(TargetNamer, self).get_name(origin_name)
|
|
113
|
-
self._origin_name_map[origin_name] = ret
|
|
114
|
-
return ret
|
|
115
|
-
|
|
116
|
-
def add_name(self, name: str):
|
|
117
|
-
super(TargetNamer, self).add_name(name)
|
|
118
|
-
self._origin_name_map[name] = name
|
|
119
|
-
|
|
120
|
-
def get_real_arg(self, origin_arg: str) -> str:
|
|
122
|
+
def get_unique_name(self, origin_name: str) -> str:
|
|
121
123
|
"""
|
|
122
|
-
Get
|
|
124
|
+
Get unique name from 'origin_name'.
|
|
123
125
|
|
|
124
126
|
Args:
|
|
125
|
-
|
|
127
|
+
origin_name (str): Origin name which may be duplicated.
|
|
126
128
|
|
|
127
129
|
Returns:
|
|
128
|
-
A string represents
|
|
130
|
+
A string represents unique-name.
|
|
129
131
|
"""
|
|
130
|
-
|
|
131
|
-
if real_arg:
|
|
132
|
-
return real_arg
|
|
133
|
-
return origin_arg
|
|
132
|
+
return super(TargetNamer, self).get_name(origin_name)
|
|
134
133
|
|
|
135
134
|
|
|
136
135
|
class NodeNamer(Namer):
|
|
@@ -153,19 +152,20 @@ class NodeNamer(Namer):
|
|
|
153
152
|
if isinstance(node_or_name, Node):
|
|
154
153
|
origin_name = node_or_name.get_name()
|
|
155
154
|
if origin_name is None or not origin_name:
|
|
156
|
-
if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.CallFunction
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
targets = node_or_name.get_targets()
|
|
160
|
-
# return node and head node will not call this method
|
|
161
|
-
if not targets:
|
|
162
|
-
raise RuntimeError("node should has at lease one target except return-node and head-node: ",
|
|
163
|
-
node_or_name)
|
|
164
|
-
origin_name = str(targets[0].value)
|
|
155
|
+
if node_or_name.get_node_type() in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.CallFunction,
|
|
156
|
+
NodeType.Tree):
|
|
157
|
+
origin_name = type(node_or_name.get_instance()).__name__
|
|
165
158
|
elif node_or_name.get_node_type() == NodeType.Python:
|
|
166
|
-
|
|
159
|
+
if node_or_name.get_instance():
|
|
160
|
+
origin_name = type(node_or_name.get_instance()).__name__
|
|
161
|
+
else:
|
|
162
|
+
origin_name = "python_node"
|
|
167
163
|
elif node_or_name.get_node_type() == NodeType.Input:
|
|
168
164
|
origin_name = "parameter"
|
|
165
|
+
elif node_or_name.get_node_type() == NodeType.Output:
|
|
166
|
+
origin_name = "return"
|
|
167
|
+
elif node_or_name.get_node_type() == NodeType.MathOps:
|
|
168
|
+
origin_name = "math_ops"
|
|
169
169
|
else:
|
|
170
170
|
raise RuntimeError("Node type unsupported:", node_or_name.get_node_type())
|
|
171
171
|
elif isinstance(node_or_name, str):
|
mindspore/rewrite/namespace.py
CHANGED
|
@@ -21,12 +21,21 @@ _ms_nn_ns = CellNamespace('mindspore.nn')
|
|
|
21
21
|
_ms_ops_ns = CellNamespace('mindspore.ops.operations')
|
|
22
22
|
_ms_functional_ns = CellNamespace('mindspore.ops.functional')
|
|
23
23
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
24
|
+
# Elements in _subtree_black_list will not be converted to symbol tree.
|
|
25
|
+
# Only str and types are stored in _subtree_black_list.
|
|
26
|
+
_subtree_black_list = ["QuantizeWrapperCell",]
|
|
27
|
+
|
|
28
|
+
def is_subtree(cls_inst):
|
|
29
|
+
"""Determine whether 'cls_inst' is a subtree."""
|
|
30
|
+
cls_name = type(cls_inst).__name__
|
|
31
|
+
black_list_types = tuple([elem for elem in _subtree_black_list if not isinstance(elem, str)])
|
|
32
|
+
if cls_name in _subtree_black_list or isinstance(cls_inst, black_list_types):
|
|
33
|
+
return False
|
|
34
|
+
if cls_name in _ms_common_ns and isinstance(cls_inst, _ms_common_ns[cls_name]):
|
|
35
|
+
return False
|
|
36
|
+
if cls_name in _ms_nn_ns and isinstance(cls_inst, _ms_nn_ns[cls_name]):
|
|
28
37
|
return False
|
|
29
|
-
if cls_name in
|
|
38
|
+
if cls_name in _ms_ops_ns and isinstance(cls_inst, _ms_ops_ns[cls_name]):
|
|
30
39
|
return False
|
|
31
40
|
|
|
32
41
|
return True
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -12,6 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
15
|
+
"""
|
|
16
|
+
SymbolTree node
|
|
17
|
+
"""
|
|
18
|
+
from mindspore.rewrite.node.node import Node, TreeNode
|
|
19
|
+
from mindspore.rewrite.node.node_manager import NodeManager
|
|
20
|
+
from mindspore.rewrite.node.call_function import CallFunction
|
|
21
|
+
from mindspore.rewrite.node.cell_container import CellContainer
|
|
22
|
+
from mindspore.rewrite.node.control_flow import ControlFlow
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""CallFunction Node."""
|
|
16
|
+
import ast
|
|
17
|
+
from .node import Node
|
|
18
|
+
from .node_manager import NodeManager
|
|
19
|
+
from ..api.scoped_value import ScopedValue
|
|
20
|
+
from ..api.node_type import NodeType
|
|
21
|
+
from ..ast_helpers import AstModifier
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class CallFunction(Node, NodeManager):
|
|
25
|
+
"""CallFunction is used for class internal function."""
|
|
26
|
+
def __init__(self, targets: [ScopedValue], func_name: ScopedValue, args: [ScopedValue],
|
|
27
|
+
kwargs: {str: ScopedValue}, node_name: str, ast_node: ast.AST, ast_functiondef: ast.FunctionDef,
|
|
28
|
+
stree, instance):
|
|
29
|
+
"""
|
|
30
|
+
Constructor of CallFunction.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
34
|
+
args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
35
|
+
kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class.
|
|
36
|
+
func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class.
|
|
37
|
+
node_name (str): A string represents name of node. Name of node will be unique when inserted into
|
|
38
|
+
SymbolTree. Name of node also used as field name in network class.
|
|
39
|
+
ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast.
|
|
40
|
+
ast_functiondef (ast.FunctionDef): An instance of ast.FunctionDef represents corresponding function
|
|
41
|
+
definition in ast.
|
|
42
|
+
stree (SymbolTree): Symbol tree used to get node_namer.
|
|
43
|
+
instance: Object in network corresponding to this node.
|
|
44
|
+
"""
|
|
45
|
+
if isinstance(func_name, str):
|
|
46
|
+
func_name = ScopedValue.create_naming_value(func_name)
|
|
47
|
+
Node.__init__(self, NodeType.CallFunction, ast_node, targets, func_name, args, kwargs, node_name, instance)
|
|
48
|
+
NodeManager.__init__(self, stree.get_node_namer())
|
|
49
|
+
NodeManager.set_ast_functiondef(self, ast_functiondef)
|
|
50
|
+
NodeManager.set_manager_name(self, func_name.value)
|
|
51
|
+
|
|
52
|
+
def erase_node(self, node):
|
|
53
|
+
"""Erase node from CallFunction."""
|
|
54
|
+
NodeManager.erase_node(self, node)
|
|
55
|
+
# erase asts
|
|
56
|
+
ret = AstModifier.erase_ast_from_function(self.get_ast_functiondef(), node.get_ast())
|
|
57
|
+
if not ret:
|
|
58
|
+
raise ValueError(f"erase node failed, node {node.get_name()} not in function ast tree.")
|
|
59
|
+
|
|
60
|
+
def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
|
|
61
|
+
"""
|
|
62
|
+
Insert a node before or after base_node.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
new_node (Node): Node to be inserted.
|
|
66
|
+
base_node (Node): New node will be inserted before or after base_node.
|
|
67
|
+
before_node (bool): Indicate whether new node is inserted before base_node.
|
|
68
|
+
insert_to_ast (bool): Indicate whether ast nodes need to be updated.
|
|
69
|
+
"""
|
|
70
|
+
NodeManager.insert_node(self, new_node, base_node, before_node)
|
|
71
|
+
if insert_to_ast:
|
|
72
|
+
stree = self.get_belong_symbol_tree()
|
|
73
|
+
stree.insert_to_ast_while_insert_node(new_node, base_node, before_node, self)
|
|
74
|
+
|
|
75
|
+
def set_belong_symbol_tree(self, symbol_tree):
|
|
76
|
+
"""Set the symbol tree to which node belongs."""
|
|
77
|
+
self._belong_tree = symbol_tree
|
|
78
|
+
for node in self.nodes():
|
|
79
|
+
node.set_belong_symbol_tree(symbol_tree)
|