mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. See the registry's release page for this version for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -13,8 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Rewrite module api: SymbolTree."""
|
|
16
|
-
from typing import Optional
|
|
17
|
-
from types import FunctionType
|
|
16
|
+
from typing import Optional, Union, List
|
|
18
17
|
import mindspore as ms
|
|
19
18
|
|
|
20
19
|
from mindspore.nn import Cell
|
|
@@ -29,10 +28,17 @@ MsDtypes = (ms.float16, ms.float32, ms.float64)
|
|
|
29
28
|
|
|
30
29
|
class SymbolTree:
|
|
31
30
|
"""
|
|
32
|
-
|
|
31
|
+
SymbolTree stores information about a network, including statements of the network's forward
|
|
32
|
+
computation process and the topological relationship between statement input and output.
|
|
33
|
+
|
|
34
|
+
The statements in the network are saved in the SymbolTree in the form of nodes, and by processing
|
|
35
|
+
the nodes in the SymbolTree, you can delete the network code, insert and replace it, and get the
|
|
36
|
+
modified network code and network instances.
|
|
33
37
|
|
|
34
38
|
Args:
|
|
35
|
-
handler (SymbolTreeImpl): SymbolTree internal implementation instance.
|
|
39
|
+
handler (SymbolTreeImpl): SymbolTree internal implementation instance. It is recommended to call the `create`
|
|
40
|
+
method in SymbolTree to create a SymbolTree, rather than calling SymbolTree's constructor directly.
|
|
41
|
+
Don't care what `SymbolTreeImpl` is, just treat it as a handle.
|
|
36
42
|
"""
|
|
37
43
|
|
|
38
44
|
def __init__(self, handler: SymbolTreeImpl):
|
|
@@ -42,16 +48,83 @@ class SymbolTree:
|
|
|
42
48
|
@classmethod
|
|
43
49
|
def create(cls, network):
|
|
44
50
|
"""
|
|
45
|
-
Create a
|
|
51
|
+
Create a SymbolTree object by passing in the network instance `network`.
|
|
52
|
+
|
|
53
|
+
This interface parses the `network` instance, expands each source
|
|
54
|
+
code statement of the forward computation process, and parses it into nodes,
|
|
55
|
+
which is stored in the SymbolTree. The specific process is as follows:
|
|
56
|
+
|
|
57
|
+
1. Obtain the source code of the network instance.
|
|
58
|
+
2. Perform AST parsing on the network and obtain the AST nodes (abstract syntax trees) of each
|
|
59
|
+
statement in the network.
|
|
60
|
+
3. Expand complex statements in the network forward evaluation process into multiple simple statements.
|
|
61
|
+
4. Create a SymbolTree object. Each SymbolTree corresponds to one network instance.
|
|
62
|
+
5. Use the rewrite node to store each statement of the network forward computation process. The node records
|
|
63
|
+
the input, output, and other information of the statement.
|
|
64
|
+
6. Save the rewrite node to the SymbolTree, and update and maintain the topological connection between
|
|
65
|
+
the nodes.
|
|
66
|
+
7. Return the SymbolTree object corresponding to the network instance.
|
|
67
|
+
|
|
68
|
+
If a user-defined network of type :class:`mindspore.nn.Cell` is called in the forward computation process
|
|
69
|
+
of the network, rewrite will generate a node of type `NodeType.Tree` for the corresponding statement. This
|
|
70
|
+
type of node stores a new SymbolTree, which parses and maintains the node information of the user-defined
|
|
71
|
+
network.
|
|
72
|
+
|
|
73
|
+
If the following types of statements are called in the forward computation process of the network, rewrite
|
|
74
|
+
will parse the internal statements in the statement and generate corresponding nodes:
|
|
75
|
+
|
|
76
|
+
- :class:`mindspore.nn.SequentialCell`
|
|
77
|
+
- Functions within classes
|
|
78
|
+
- Control flow statements, such as `if` statements
|
|
79
|
+
|
|
80
|
+
Note:
|
|
81
|
+
Because the specific execution branch of control flows are still unknown during the rewrite operation
|
|
82
|
+
of the network, no topology information will be established between the nodes inside the control flow
|
|
83
|
+
and the nodes outside.
|
|
84
|
+
Users cannot obtain nodes inside the control flow when they acquire nodes outside the control flow using
|
|
85
|
+
interfaces like :func:`mindspore.rewrite.Node.get_inputs` and :func:`mindspore.rewrite.Node.get_users` .
|
|
86
|
+
Users also cannot obtain nodes outside the control flow, if they use these interfaces inside the control
|
|
87
|
+
flow.
|
|
88
|
+
Therefore, when users modify the network, they need to manually handle the node information inside and
|
|
89
|
+
outside the control flow.
|
|
90
|
+
|
|
91
|
+
The current rewrite module has the following syntax limitations:
|
|
92
|
+
|
|
93
|
+
- Only networks of type :class:`mindspore.nn.Cell` are supported as input to the rewrite module.
|
|
94
|
+
- Parsing assignment statements with multiple output values is not currently supported.
|
|
95
|
+
- Parsing loop statements is not currently supported.
|
|
96
|
+
- Parsing decorator syntax is not currently supported.
|
|
97
|
+
- Parsing class variable syntax is not currently supported. If class variable uses external data,
|
|
98
|
+
the network after rewrite may be missing data.
|
|
99
|
+
- Parsing local classes and embedded classes is not currently supported, that is, the definition
|
|
100
|
+
of classes need to be placed on the outermost layer.
|
|
101
|
+
- Parsing closure syntax is not currently supported, that is, the definition of out-of-class
|
|
102
|
+
functions need to be placed at the outermost layer.
|
|
103
|
+
- Parsing lambda expression syntax is not currently supported.
|
|
104
|
+
|
|
105
|
+
For statements that do not support parsing, rewrite will generate nodes of type `NodeType.Python`
|
|
106
|
+
for corresponding statements to ensure that the network after rewrite can run normally.
|
|
107
|
+
The `Python` node does not support modifying the input and output of statements, and there may be
|
|
108
|
+
a problem between variable names and those generated by the rewrite. In this case, users need to
|
|
109
|
+
adjust the variable names manually.
|
|
46
110
|
|
|
47
111
|
Args:
|
|
48
|
-
network (Cell): `network` used to create
|
|
112
|
+
network (Cell): `network` used to create SymbolTree.
|
|
49
113
|
|
|
50
114
|
Returns:
|
|
51
|
-
Symboltree, a
|
|
115
|
+
Symboltree, a SymbolTree created based on `network`.
|
|
52
116
|
|
|
53
117
|
Raises:
|
|
54
118
|
TypeError: If `network` is not a `Cell` instance.
|
|
119
|
+
|
|
120
|
+
Examples:
|
|
121
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
122
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
123
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
124
|
+
>>> net = LeNet5()
|
|
125
|
+
>>> stree = SymbolTree.create(net)
|
|
126
|
+
>>> print(type(stree))
|
|
127
|
+
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
|
|
55
128
|
"""
|
|
56
129
|
Validator.check_value_type("network", network, [Cell], "SymbolTree")
|
|
57
130
|
return cls(SymbolTreeBuilder(network).build())
|
|
@@ -70,61 +143,75 @@ class SymbolTree:
|
|
|
70
143
|
if v not in MsDtypes and not isinstance(v, ParamTypes):
|
|
71
144
|
raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")
|
|
72
145
|
|
|
73
|
-
def create_call_function(self, func, targets, *args, **kwargs): # pylint: disable=C0111
|
|
74
|
-
Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
|
|
75
|
-
Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
|
|
76
|
-
args_ = list(args)
|
|
77
|
-
SymbolTree._check_args_type(args_)
|
|
78
|
-
for i, arg in enumerate(args_):
|
|
79
|
-
if isinstance(arg, Node):
|
|
80
|
-
args_[i] = arg.get_handler()
|
|
81
|
-
SymbolTree._check_kwargs_type(kwargs)
|
|
82
|
-
for key, value in kwargs.items():
|
|
83
|
-
if isinstance(value, Node):
|
|
84
|
-
kwargs[key] = value.get_handler()
|
|
85
|
-
return Node(self._symbol_tree._create_call_function(func, targets, args_, kwargs)) # pylint: disable=W0212
|
|
86
|
-
|
|
87
146
|
def get_handler(self) -> SymbolTreeImpl:
|
|
88
147
|
return self._symbol_tree
|
|
89
148
|
|
|
90
|
-
def nodes(self):
|
|
149
|
+
def nodes(self, all_nodes: bool = False):
|
|
91
150
|
"""
|
|
92
|
-
Get
|
|
151
|
+
Get the generator of the node in the current SymbolTree, which is used to iterate
|
|
152
|
+
through the nodes in SymbolTree.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
all_nodes (bool): Get all nodes including nodes in CallFunction node, CellContainer node
|
|
156
|
+
and sub symbol tree. Default: ``False`` .
|
|
93
157
|
|
|
94
158
|
Returns:
|
|
95
|
-
A generator for
|
|
159
|
+
A generator for nodes in SymbolTree.
|
|
160
|
+
|
|
161
|
+
Raises:
|
|
162
|
+
TypeError: If `all_nodes` is not bool.
|
|
96
163
|
|
|
97
164
|
Examples:
|
|
98
165
|
>>> from mindspore.rewrite import SymbolTree
|
|
99
|
-
>>>
|
|
100
|
-
>>>
|
|
166
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
167
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
168
|
+
>>> net = LeNet5()
|
|
101
169
|
>>> stree = SymbolTree.create(net)
|
|
102
|
-
>>> for node in stree.nodes()
|
|
103
|
-
|
|
170
|
+
>>> print([node.get_name() for node in stree.nodes()])
|
|
171
|
+
['input_x', 'Expr', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1',
|
|
172
|
+
'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return']
|
|
104
173
|
"""
|
|
105
|
-
|
|
174
|
+
Validator.check_value_type("all_nodes", all_nodes, [bool], "nodes")
|
|
175
|
+
nodes = self._symbol_tree.all_nodes() if all_nodes else self._symbol_tree.nodes()
|
|
176
|
+
for node in nodes:
|
|
106
177
|
yield Node(node)
|
|
107
178
|
|
|
108
179
|
def get_node(self, node_name: str) -> Optional[Node]:
|
|
180
|
+
"""
|
|
181
|
+
Get the node with the name `node_name` in the SymbolTree.
|
|
182
|
+
|
|
183
|
+
Args:
|
|
184
|
+
node_name (str): The name of node.
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
Node with name of `node_name` . Return ``None`` if there is no node named `node_name` in SymbolTree.
|
|
188
|
+
|
|
189
|
+
Examples:
|
|
190
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
191
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
192
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
193
|
+
>>> net = LeNet5()
|
|
194
|
+
>>> stree = SymbolTree.create(net)
|
|
195
|
+
>>> node = stree.get_node('conv1')
|
|
196
|
+
>>> print(node.get_name())
|
|
197
|
+
conv1
|
|
198
|
+
"""
|
|
109
199
|
Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
|
|
110
|
-
node_impl = self._symbol_tree.
|
|
200
|
+
node_impl = self._symbol_tree.get_node_from_name(node_name)
|
|
111
201
|
if node_impl is None:
|
|
112
202
|
return None
|
|
113
203
|
return Node(node_impl)
|
|
114
204
|
|
|
115
|
-
def get_inputs(self) -> [Node]:
|
|
205
|
+
def get_inputs(self) -> List[Node]:
|
|
116
206
|
return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]
|
|
117
207
|
|
|
118
|
-
def before(self, node: Node):
|
|
208
|
+
def before(self, node: Union[Node, str]):
|
|
119
209
|
"""
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
`Position` is used to indicate where to insert node, it indicates position in source code rather than position
|
|
123
|
-
in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as
|
|
124
|
-
an arguments of `insert` api of `SymbolTree`.
|
|
210
|
+
Returns a location information before `node`. The return value of this interface is
|
|
211
|
+
used as a parameter for the insert operation.
|
|
125
212
|
|
|
126
213
|
Args:
|
|
127
|
-
node (Node): Indicate the position before which node. Can be a node or name of node.
|
|
214
|
+
node (Union[Node, str]): Indicate the position before which node. Can be a node or name of node.
|
|
128
215
|
|
|
129
216
|
Returns:
|
|
130
217
|
A `Position` to indicate where to insert node.
|
|
@@ -134,26 +221,26 @@ class SymbolTree:
|
|
|
134
221
|
|
|
135
222
|
Examples:
|
|
136
223
|
>>> from mindspore.rewrite import SymbolTree
|
|
137
|
-
>>>
|
|
138
|
-
>>>
|
|
224
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
225
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
226
|
+
>>> net = LeNet5()
|
|
139
227
|
>>> stree = SymbolTree.create(net)
|
|
140
228
|
>>> for node in stree.nodes():
|
|
141
229
|
... if node.get_name() == "conv1":
|
|
142
230
|
... position = stree.before(node)
|
|
143
231
|
"""
|
|
144
|
-
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
145
|
-
|
|
232
|
+
Validator.check_value_type("node", node, [Node, str], "SymbolTree")
|
|
233
|
+
if isinstance(node, Node):
|
|
234
|
+
node = node.get_handler()
|
|
235
|
+
return self._symbol_tree.before(node)
|
|
146
236
|
|
|
147
|
-
def after(self, node: Node):
|
|
237
|
+
def after(self, node: Union[Node, str]):
|
|
148
238
|
"""
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
`Position` is used to indicate where to insert node, it indicates position in source code rather than position
|
|
152
|
-
in topological order. We don't need to care about what `Position` is, just treat it as a handler and use it as
|
|
153
|
-
an arguments of `insert` api of `SymbolTree`.
|
|
239
|
+
Returns a location information after `node`. The return value of this interface is
|
|
240
|
+
used as a parameter for the insert operation.
|
|
154
241
|
|
|
155
242
|
Args:
|
|
156
|
-
node (Node): Indicate the position after which node. Can be a node or name of node.
|
|
243
|
+
node (Union[Node, str]): Indicate the position after which node. Can be a node or name of node.
|
|
157
244
|
|
|
158
245
|
Returns:
|
|
159
246
|
A `Position` to indicate where to insert node.
|
|
@@ -163,15 +250,18 @@ class SymbolTree:
|
|
|
163
250
|
|
|
164
251
|
Examples:
|
|
165
252
|
>>> from mindspore.rewrite import SymbolTree
|
|
166
|
-
>>>
|
|
167
|
-
>>>
|
|
253
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
254
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
255
|
+
>>> net = LeNet5()
|
|
168
256
|
>>> stree = SymbolTree.create(net)
|
|
169
257
|
>>> for node in stree.nodes():
|
|
170
258
|
... if node.get_name() == "conv1":
|
|
171
259
|
... position = stree.after(node)
|
|
172
260
|
"""
|
|
173
|
-
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
174
|
-
|
|
261
|
+
Validator.check_value_type("node", node, [Node, str], "SymbolTree")
|
|
262
|
+
if isinstance(node, Node):
|
|
263
|
+
node = node.get_handler()
|
|
264
|
+
return self._symbol_tree.after(node)
|
|
175
265
|
|
|
176
266
|
def insert(self, position, node: Node) -> Node:
|
|
177
267
|
"""
|
|
@@ -184,8 +274,7 @@ class SymbolTree:
|
|
|
184
274
|
node (Node): An instance of Node to be inserted.
|
|
185
275
|
|
|
186
276
|
Returns:
|
|
187
|
-
An instance of Node being inserted.
|
|
188
|
-
custom-object in args or kwargs.
|
|
277
|
+
An instance of Node being inserted.
|
|
189
278
|
|
|
190
279
|
Raises:
|
|
191
280
|
RuntimeError: If `position` is not belong to current `SymbolTree`.
|
|
@@ -193,67 +282,64 @@ class SymbolTree:
|
|
|
193
282
|
TypeError: If `node` is not a `Node`.
|
|
194
283
|
|
|
195
284
|
Examples:
|
|
196
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
197
|
-
>>>
|
|
198
|
-
>>>
|
|
285
|
+
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
286
|
+
>>> import mindspore.nn as nn
|
|
287
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
288
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
289
|
+
>>> net = LeNet5()
|
|
199
290
|
>>> stree = SymbolTree.create(net)
|
|
200
291
|
>>> node = stree.get_node("conv1")
|
|
201
292
|
>>> position = stree.after(node)
|
|
202
|
-
>>> new_node =
|
|
293
|
+
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
|
|
294
|
+
... args=[ScopedValue.create_naming_value('x')], name='new_relu')
|
|
203
295
|
>>> stree.insert(position, new_node)
|
|
204
296
|
"""
|
|
205
297
|
Validator.check_value_type("position", position, [Position], "SymbolTree")
|
|
206
298
|
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
207
|
-
return Node(self._symbol_tree.insert_node(
|
|
299
|
+
return Node(self._symbol_tree.insert_node(node.get_handler(), position.node, position.before_node))
|
|
208
300
|
|
|
209
|
-
def
|
|
301
|
+
def erase(self, node: Union[Node, str]) -> Optional[Node]:
|
|
210
302
|
"""
|
|
211
|
-
Erase a `node` from rewrite.
|
|
303
|
+
Erase a `node` from rewrite.
|
|
212
304
|
|
|
213
305
|
Args:
|
|
214
|
-
node (Node): A `Node` to be erased. Can be a node or name of node.
|
|
306
|
+
node (Union[Node, str]): A `Node` to be erased. Can be a node or name of node.
|
|
215
307
|
|
|
216
308
|
Returns:
|
|
217
309
|
An instance of `Node` being erased if node is in `SymbolTree` else None.
|
|
218
310
|
|
|
219
311
|
Raises:
|
|
220
|
-
TypeError:
|
|
312
|
+
TypeError: The type of `node` is not Node.
|
|
221
313
|
|
|
222
314
|
Examples:
|
|
223
315
|
>>> from mindspore.rewrite import SymbolTree
|
|
224
|
-
>>>
|
|
225
|
-
>>>
|
|
316
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
317
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
318
|
+
>>> net = LeNet5()
|
|
226
319
|
>>> stree = SymbolTree.create(net)
|
|
227
320
|
>>> node = stree.get_node("conv1")
|
|
228
|
-
>>>
|
|
229
|
-
>>> output_nodes = node.get_users()
|
|
230
|
-
>>> for n in output_nodes:
|
|
231
|
-
... n.set_arg(0, "x")
|
|
232
|
-
>>> stree.erase_node(node)
|
|
321
|
+
>>> stree.erase(node)
|
|
233
322
|
"""
|
|
234
|
-
Validator.check_value_type("node", node, [Node], "SymbolTree")
|
|
235
|
-
|
|
323
|
+
Validator.check_value_type("node", node, [Node, str], "SymbolTree")
|
|
324
|
+
if isinstance(node, Node):
|
|
325
|
+
node = node.get_handler()
|
|
326
|
+
return Node(self._symbol_tree.erase_node(node))
|
|
236
327
|
|
|
237
|
-
def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
|
|
328
|
+
def replace(self, old_node: Node, new_nodes: List[Node]) -> Node:
|
|
238
329
|
"""
|
|
239
|
-
Replace `old_node` with
|
|
330
|
+
Replace the `old_node` with nodes in the `new_nodes` list.
|
|
240
331
|
|
|
241
|
-
|
|
332
|
+
Nodes in `new_nodes` will be inserted into SymbolTree sequentially, and then `old_node` will be deleted.
|
|
242
333
|
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
4. Caller should maintain arguments of input nodes of sub-tree and for specifying topological relation of
|
|
249
|
-
inputs of sub-tree.
|
|
250
|
-
5. Rewrite will maintain arguments of prepend node of sub-tree for specifying topological relation of
|
|
251
|
-
outputs of sub-tree.
|
|
252
|
-
6. Rewrite will maintain all inputs of nodes after replace `new_nodes` into `SymbolTree`.
|
|
334
|
+
Note:
|
|
335
|
+
- Replace support one-to-one replacement or one-to-multi replacement. If you need multi-to-multi
|
|
336
|
+
replacement, please refer to `PatternEngine`.
|
|
337
|
+
- Caller should maintain the topological relationship between each node in the `new_nodes` , as well as
|
|
338
|
+
the topological relationship between nodes in the `new_nodes` and nodes in the original tree.
|
|
253
339
|
|
|
254
340
|
Args:
|
|
255
341
|
old_node (Node): Node to be replaced.
|
|
256
|
-
new_nodes (
|
|
342
|
+
new_nodes (List[Node]): Nodes of the node_tree to replace in.
|
|
257
343
|
|
|
258
344
|
Returns:
|
|
259
345
|
An instance of Node represents root of node_tree been replaced in.
|
|
@@ -264,12 +350,15 @@ class SymbolTree:
|
|
|
264
350
|
TypeError: If `new_nodes` is not a `list` or node in `new_nodes` is not a `Node`.
|
|
265
351
|
|
|
266
352
|
Examples:
|
|
267
|
-
>>> from mindspore.rewrite import SymbolTree
|
|
268
|
-
>>>
|
|
269
|
-
>>>
|
|
353
|
+
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
354
|
+
>>> import mindspore.nn as nn
|
|
355
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
356
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
357
|
+
>>> net = LeNet5()
|
|
270
358
|
>>> stree = SymbolTree.create(net)
|
|
271
359
|
>>> node = stree.get_node("conv1")
|
|
272
|
-
>>> new_node =
|
|
360
|
+
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
|
|
361
|
+
... args=[ScopedValue.create_naming_value('x')], name='new_relu')
|
|
273
362
|
>>> stree.replace(node, [new_node])
|
|
274
363
|
"""
|
|
275
364
|
Validator.check_value_type("old_node", old_node, [Node], "SymbolTree")
|
|
@@ -283,44 +372,83 @@ class SymbolTree:
|
|
|
283
372
|
return Node(self._symbol_tree.set_output(return_value, index))
|
|
284
373
|
|
|
285
374
|
def dump(self):
|
|
286
|
-
"""
|
|
287
|
-
Print the ir map information corresponding to the network in 'SymbolTree' to the screen.
|
|
288
|
-
"""
|
|
289
375
|
self._symbol_tree.dump()
|
|
290
376
|
|
|
291
|
-
def print_node_tabulate(self):
|
|
292
|
-
|
|
377
|
+
def print_node_tabulate(self, all_nodes: bool = False):
|
|
378
|
+
r"""
|
|
379
|
+
Print the topology information of nodes in SymbolTree, including node type, node name, node code,
|
|
380
|
+
and node input-output relationship.
|
|
381
|
+
|
|
382
|
+
The information is output to the screen using the print interface, including the following information:
|
|
383
|
+
|
|
384
|
+
- **node type** (str): The type of node, refer to class:`mindspore.rewrite.NodeType` .
|
|
385
|
+
- **name** (str): The name of node.
|
|
386
|
+
- **codes** (str): The source code statement corresponding to the node.
|
|
387
|
+
- **arg providers** (Dict[int, Tuple[str, int]]): The format is `{[idx, (n, k)]}` , which means the
|
|
388
|
+
`idx` th parameter of the node is provided by the `k` th output of node `n` .
|
|
389
|
+
- **target users** (Dict[int, List[Tuple[str, int]]]): The format is '{[idx, [(n, k)]]}' , which means
|
|
390
|
+
the `idx` th output of the node is used as the `k` th parameter of node `n` .
|
|
391
|
+
|
|
392
|
+
Args:
|
|
393
|
+
all_nodes (bool): Print information of all nodes, including nodes in CallFunction
|
|
394
|
+
node, CellContainer node and sub symbol tree. Default: ``False`` .
|
|
395
|
+
|
|
396
|
+
Raises:
|
|
397
|
+
TypeError: If `all_nodes` is not bool.
|
|
398
|
+
|
|
399
|
+
Examples:
|
|
400
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
401
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
402
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
403
|
+
>>> net = LeNet5()
|
|
404
|
+
>>> stree = SymbolTree.create(net)
|
|
405
|
+
>>> stree.print_node_tabulate()
|
|
406
|
+
"""
|
|
407
|
+
Validator.check_value_type("all_nodes", all_nodes, [bool], "print_node_tabulate")
|
|
408
|
+
self._symbol_tree.print_node_tabulate(all_nodes)
|
|
293
409
|
|
|
294
410
|
def get_code(self) -> str:
|
|
295
411
|
"""
|
|
296
|
-
Get source code
|
|
412
|
+
Get source code corresponding to the network information in SymbolTree.
|
|
413
|
+
If the network has already been modified, the source code of modified network is returned.
|
|
297
414
|
|
|
298
415
|
Returns:
|
|
299
416
|
A str represents source code of modified network.
|
|
300
417
|
|
|
301
418
|
Examples:
|
|
302
419
|
>>> from mindspore.rewrite import SymbolTree
|
|
303
|
-
>>>
|
|
304
|
-
>>>
|
|
420
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
421
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
422
|
+
>>> net = LeNet5()
|
|
305
423
|
>>> stree = SymbolTree.create(net)
|
|
306
|
-
>>> stree.get_code()
|
|
424
|
+
>>> codes = stree.get_code()
|
|
425
|
+
>>> print(codes)
|
|
307
426
|
"""
|
|
308
427
|
return self._symbol_tree.get_code()
|
|
309
428
|
|
|
310
429
|
def get_network(self) -> Cell:
|
|
311
430
|
"""
|
|
312
|
-
Get
|
|
313
|
-
The source code
|
|
431
|
+
Get the network object generated based on SymbolTree.
|
|
432
|
+
The source code is saved to a file in the 'rewritten_network' folder of the current directory.
|
|
433
|
+
|
|
434
|
+
Note:
|
|
435
|
+
- The modification of network by rewrite module is based on the modification of AST tree of
|
|
436
|
+
original network instance, and the new network instance will obtain attribute information
|
|
437
|
+
from original network instance, so the new network instance and the original network instance
|
|
438
|
+
have data association, and the original network should no longer be used.
|
|
439
|
+
- Due to the data association between the new network and the original network instance, manually creating
|
|
440
|
+
a network instance using the source code file generated by rewrite is not currently supported.
|
|
314
441
|
|
|
315
442
|
Returns:
|
|
316
|
-
A network object.
|
|
443
|
+
A network object generated from SymbolTree.
|
|
317
444
|
|
|
318
445
|
Examples:
|
|
319
446
|
>>> from mindspore.rewrite import SymbolTree
|
|
320
|
-
>>>
|
|
321
|
-
>>>
|
|
447
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
448
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
449
|
+
>>> net = LeNet5()
|
|
322
450
|
>>> stree = SymbolTree.create(net)
|
|
323
|
-
>>> stree.get_network()
|
|
451
|
+
>>> new_net = stree.get_network()
|
|
324
452
|
"""
|
|
325
453
|
return self._symbol_tree.get_network()
|
|
326
454
|
|
|
@@ -333,3 +461,21 @@ class SymbolTree:
|
|
|
333
461
|
|
|
334
462
|
def save_network_to_file(self):
|
|
335
463
|
self._symbol_tree.save_network_to_file()
|
|
464
|
+
|
|
465
|
+
def unique_name(self, name: str = "output"):
|
|
466
|
+
"""
|
|
467
|
+
Based on the given `name` , returns a new name that is unique within the symbol tree.
|
|
468
|
+
This interface can be used when a variable name that does not conflict is required.
|
|
469
|
+
|
|
470
|
+
Args:
|
|
471
|
+
name (str, optional): The prefix of the name. Defaults to ``"output"`` .
|
|
472
|
+
|
|
473
|
+
Returns:
|
|
474
|
+
str, A new, unique name within a symbol tree in the format `name_n`, where `n` is a numeric subscript.
|
|
475
|
+
If there is no name conflict when entered `name`, there is no numeric subscript.
|
|
476
|
+
|
|
477
|
+
Raises:
|
|
478
|
+
TypeError: The type of `name` is not str.
|
|
479
|
+
"""
|
|
480
|
+
Validator.check_value_type("name", name, [str], "SymbolTree")
|
|
481
|
+
return self._symbol_tree.unique_name(name)
|
|
@@ -28,6 +28,9 @@ class TreeNodeHelper:
|
|
|
28
28
|
`TreeNodeHelper` is used to break circle reference while getting symbol_tree from a `Tree` type `Node`.
|
|
29
29
|
|
|
30
30
|
`TreeNodeHelper` provides a staticmethod `get_sub_tree` for getting symbol_tree from a `Tree` type `Node`.
|
|
31
|
+
|
|
32
|
+
.. warning::
|
|
33
|
+
This is a set of experimental APIs that is subject to change or deletion.
|
|
31
34
|
"""
|
|
32
35
|
|
|
33
36
|
@staticmethod
|
|
@@ -17,7 +17,8 @@
|
|
|
17
17
|
Define some ast helpers for manipulating python ast.
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
|
-
from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj
|
|
20
|
+
from .ast_finder import AstFinder, StrChecker, CheckPropertyIsUsed, GetPropertyOfObj, \
|
|
21
|
+
AstAssignFinder, AstClassFinder, AstFunctionFinder
|
|
21
22
|
from .ast_replacer import AstReplacer
|
|
22
23
|
from .ast_modifier import AstModifier
|
|
23
24
|
from .ast_creator import ast_args_creator, ast_assign_creator, ast_attributer_creator, ast_call_creator, \
|