mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/rewrite/api/node.py
CHANGED
|
@@ -14,12 +14,13 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Rewrite module api: Node."""
|
|
16
16
|
|
|
17
|
-
from typing import Union, Optional
|
|
17
|
+
from typing import Union, Optional, List, Dict
|
|
18
|
+
from types import FunctionType
|
|
18
19
|
|
|
19
20
|
from mindspore.nn import Cell
|
|
20
21
|
from mindspore.ops.primitive import Primitive
|
|
21
22
|
from mindspore import _checkparam as Validator
|
|
22
|
-
from ..node import Node as NodeImpl
|
|
23
|
+
from ..node.node import Node as NodeImpl
|
|
23
24
|
from ..symbol_tree import SymbolTree as SymbolTreeImpl
|
|
24
25
|
from .node_type import NodeType
|
|
25
26
|
from .scoped_value import ScopedValue
|
|
@@ -27,16 +28,17 @@ from .scoped_value import ScopedValue
|
|
|
27
28
|
|
|
28
29
|
class Node:
|
|
29
30
|
"""
|
|
30
|
-
|
|
31
|
+
A node is a data structure that expresses source code statements in a network.
|
|
31
32
|
|
|
32
|
-
|
|
33
|
-
|
|
33
|
+
Each node usually corresponds to a statement in expanded forward evaluation process.
|
|
34
|
+
|
|
35
|
+
Nodes can express a ``Cell`` call statement, a ``Primitive`` call statement, an arithmetic operation statement, a
|
|
36
|
+
return statements, etc. of the forward calculation process.
|
|
34
37
|
|
|
35
38
|
Args:
|
|
36
|
-
node (NodeImpl): A handler of `NodeImpl`.
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
directly, so don't care about what is `NodeImpl` and use its instance just as a handler.
|
|
39
|
+
node (NodeImpl): A handler of `NodeImpl`. It is recommended to call the specific methods in Node to create
|
|
40
|
+
a Node, such as 'create_call_cell', rather than calling the Node's constructor directly.
|
|
41
|
+
Don't care what `NodeImpl` is, just treat it as a handle.
|
|
40
42
|
"""
|
|
41
43
|
|
|
42
44
|
def __init__(self, node: NodeImpl):
|
|
@@ -49,8 +51,8 @@ class Node:
|
|
|
49
51
|
return self._node == other._node
|
|
50
52
|
|
|
51
53
|
@staticmethod
|
|
52
|
-
def create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None,
|
|
53
|
-
kwargs:
|
|
54
|
+
def create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None,
|
|
55
|
+
kwargs: Dict[str, ScopedValue] = None, name: str = "", is_sub_net: bool = False) -> 'Node':
|
|
54
56
|
"""
|
|
55
57
|
Create a node. Only support create from a `Cell` now.
|
|
56
58
|
|
|
@@ -62,20 +64,18 @@ class Node:
|
|
|
62
64
|
|
|
63
65
|
Args:
|
|
64
66
|
cell (Cell): Cell-operator of this forward-layer.
|
|
65
|
-
targets (
|
|
66
|
-
|
|
67
|
-
args (
|
|
68
|
-
source code. Default
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
Default is None indicate the `cell` has no kwargs inputs. Rewrite will check and ensure the uniqueness
|
|
73
|
-
of each kwarg while node being inserted.
|
|
67
|
+
targets (List[Union[ScopedValue, str]]): Indicate output names. Used as targets of an assign statement in
|
|
68
|
+
source code.
|
|
69
|
+
args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
|
|
70
|
+
source code. Default: ``None`` , which indicates the `cell` has no args inputs.
|
|
71
|
+
kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
|
|
72
|
+
Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
|
|
73
|
+
code. Default: ``None`` , which indicates the `cell` has no kwargs inputs.
|
|
74
74
|
name (str): Indicate the name of node. Used as field name in source code. Default is None. Rewrite will
|
|
75
|
-
generate name from `
|
|
76
|
-
while node being inserted.
|
|
77
|
-
is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse
|
|
78
|
-
`cell` to a TreeNode,
|
|
75
|
+
generate name from `cell` when name is None. Rewrite will check and ensure the uniqueness of `name`
|
|
76
|
+
while node being inserted. Default: ``""`` .
|
|
77
|
+
is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse
|
|
78
|
+
the `cell` to a TreeNode, otherwise the `cell` is parsed to a CallCell node. Default: ``False`` .
|
|
79
79
|
|
|
80
80
|
Returns:
|
|
81
81
|
An instance of `Node`.
|
|
@@ -86,6 +86,21 @@ class Node:
|
|
|
86
86
|
TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
|
|
87
87
|
TypeError: If arg in `args` is not a `ScopedValue`.
|
|
88
88
|
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.
|
|
89
|
+
|
|
90
|
+
Examples:
|
|
91
|
+
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
92
|
+
>>> import mindspore.nn as nn
|
|
93
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
94
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
95
|
+
>>> net = LeNet5()
|
|
96
|
+
>>> stree = SymbolTree.create(net)
|
|
97
|
+
>>> node = stree.get_node("conv1")
|
|
98
|
+
>>> position = stree.after(node)
|
|
99
|
+
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
|
|
100
|
+
... args=[ScopedValue.create_naming_value('x')], name='new_relu')
|
|
101
|
+
>>> stree.insert(position, new_node)
|
|
102
|
+
>>> print(type(new_node))
|
|
103
|
+
<class 'mindspore.rewrite.api.node.Node'>
|
|
89
104
|
"""
|
|
90
105
|
Validator.check_value_type("cell", cell, [Cell, Primitive], "Node")
|
|
91
106
|
Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
|
|
@@ -95,35 +110,107 @@ class Node:
|
|
|
95
110
|
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
|
|
96
111
|
if kwargs is not None:
|
|
97
112
|
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
|
|
98
|
-
return Node(NodeImpl.create_call_op(cell, None, targets,
|
|
99
|
-
|
|
113
|
+
return Node(NodeImpl.create_call_op(cell, None, targets, args, kwargs, name, is_sub_net))
|
|
114
|
+
|
|
115
|
+
@staticmethod
|
|
116
|
+
def create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]],
|
|
117
|
+
args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None) -> 'Node':
|
|
118
|
+
"""
|
|
119
|
+
Create a node that corresponds to a function call. The `function` object is saved into network, and used via
|
|
120
|
+
getting object from `self.` .
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
function (FunctionType): The function to be called.
|
|
124
|
+
targets (List[Union[ScopedValue, str]]): indicates output names. Used as targets of an assign statement in
|
|
125
|
+
source code.
|
|
126
|
+
args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
|
|
127
|
+
source code. Default: ``None`` , which indicates the `function` has no args inputs.
|
|
128
|
+
kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
|
|
129
|
+
Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
|
|
130
|
+
code. Default: ``None`` , which indicates the `function` has no kwargs inputs.
|
|
131
|
+
|
|
132
|
+
Returns:
|
|
133
|
+
An instance of `Node`.
|
|
134
|
+
|
|
135
|
+
Raises:
|
|
136
|
+
TypeError: If `function` is not a `FunctionType`.
|
|
137
|
+
TypeError: If `targets` is not `list`.
|
|
138
|
+
TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
|
|
139
|
+
TypeError: If arg in `args` is not a `ScopedValue`.
|
|
140
|
+
TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.
|
|
141
|
+
|
|
142
|
+
Examples:
|
|
143
|
+
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
144
|
+
>>> import mindspore.nn as nn
|
|
145
|
+
>>> import mindspore.ops as ops
|
|
146
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
147
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
148
|
+
>>> net = LeNet5()
|
|
149
|
+
>>> stree = SymbolTree.create(net)
|
|
150
|
+
>>> node = stree.get_node("conv1")
|
|
151
|
+
>>> position = stree.after(node)
|
|
152
|
+
>>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
|
|
153
|
+
... args=[ScopedValue.create_naming_value('x')])
|
|
154
|
+
>>> stree.insert(position, new_node)
|
|
155
|
+
>>> print(new_node.get_node_type())
|
|
156
|
+
NodeType.CallFunction
|
|
157
|
+
"""
|
|
158
|
+
Validator.check_value_type("function", function, [FunctionType, type], "create_call_function")
|
|
159
|
+
Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "create_call_function")
|
|
160
|
+
if args is not None:
|
|
161
|
+
Validator.check_element_type_of_iterable("args", args, [ScopedValue], "create_call_function")
|
|
162
|
+
if kwargs is not None:
|
|
163
|
+
Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "create_call_function")
|
|
164
|
+
return Node(NodeImpl._create_call_function(function, targets, args, kwargs))
|
|
165
|
+
|
|
166
|
+
@staticmethod
|
|
167
|
+
def create_input(param_name: str, default: Optional[ScopedValue] = None) -> 'Node':
|
|
168
|
+
# pylint: disable=missing-function-docstring
|
|
169
|
+
Validator.check_value_type("param_name", param_name, [str], "Node")
|
|
170
|
+
if default is not None:
|
|
171
|
+
Validator.check_value_type("default", default, [ScopedValue], "Node")
|
|
172
|
+
return Node(NodeImpl.create_input_node(None, param_name, default, name=f"input_{param_name}"))
|
|
100
173
|
|
|
101
174
|
def get_handler(self) -> NodeImpl:
|
|
102
175
|
return self._node
|
|
103
176
|
|
|
104
177
|
def get_inputs(self) -> ['Node']:
|
|
105
178
|
"""
|
|
106
|
-
|
|
179
|
+
Gets a list of nodes whose output values are used as input values for the current node.
|
|
107
180
|
|
|
108
181
|
Returns:
|
|
109
|
-
A list of
|
|
182
|
+
A list of nodes.
|
|
183
|
+
|
|
184
|
+
Examples:
|
|
185
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
186
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
187
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
188
|
+
>>> net = LeNet5()
|
|
189
|
+
>>> stree = SymbolTree.create(net)
|
|
190
|
+
>>> node = stree.get_node("conv2")
|
|
191
|
+
>>> inputs = node.get_inputs()
|
|
192
|
+
>>> print([input.get_name() for input in inputs])
|
|
193
|
+
['max_pool2d']
|
|
110
194
|
"""
|
|
111
195
|
return [Node(node_impl) for node_impl in self._node.get_inputs()]
|
|
112
196
|
|
|
113
197
|
def get_users(self) -> ['Node']:
|
|
114
198
|
"""
|
|
115
|
-
Get
|
|
199
|
+
Get a list of nodes that use the output of the current node as input.
|
|
116
200
|
|
|
117
201
|
Returns:
|
|
118
|
-
A list of nodes
|
|
202
|
+
A list of nodes.
|
|
119
203
|
|
|
120
204
|
Examples:
|
|
121
205
|
>>> from mindspore.rewrite import SymbolTree
|
|
122
|
-
>>>
|
|
123
|
-
>>>
|
|
206
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
207
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
208
|
+
>>> net = LeNet5()
|
|
124
209
|
>>> stree = SymbolTree.create(net)
|
|
125
210
|
>>> node = stree.get_node("conv1")
|
|
126
211
|
>>> users = node.get_users()
|
|
212
|
+
>>> print([user.get_name() for user in users])
|
|
213
|
+
['relu']
|
|
127
214
|
"""
|
|
128
215
|
belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
|
|
129
216
|
if belong_symbol_tree is None:
|
|
@@ -149,11 +236,14 @@ class Node:
|
|
|
149
236
|
|
|
150
237
|
Examples:
|
|
151
238
|
>>> from mindspore.rewrite import SymbolTree
|
|
152
|
-
>>>
|
|
153
|
-
>>>
|
|
239
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
240
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
241
|
+
>>> net = LeNet5()
|
|
154
242
|
>>> stree = SymbolTree.create(net)
|
|
155
|
-
>>> node = stree.get_node("
|
|
156
|
-
>>> node.set_arg(0, "
|
|
243
|
+
>>> node = stree.get_node("relu_3")
|
|
244
|
+
>>> node.set_arg(0, "fc1")
|
|
245
|
+
>>> print(node.get_args())
|
|
246
|
+
[fc1]
|
|
157
247
|
"""
|
|
158
248
|
Validator.check_value_type("index", index, [int], "Node")
|
|
159
249
|
Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
|
|
@@ -170,7 +260,8 @@ class Node:
|
|
|
170
260
|
Args:
|
|
171
261
|
arg_idx (int): Indicate which input being modified.
|
|
172
262
|
src_node (Node): A `Node` as new input. Can be a node or name of node.
|
|
173
|
-
out_idx (int, optional): Indicate which output of `src_node` as new input of current node.
|
|
263
|
+
out_idx (int, optional): Indicate which output of `src_node` as new input of current node.
|
|
264
|
+
Default: ``None`` ,
|
|
174
265
|
which means use first output of `src_node` as new input.
|
|
175
266
|
|
|
176
267
|
Raises:
|
|
@@ -184,12 +275,15 @@ class Node:
|
|
|
184
275
|
|
|
185
276
|
Examples:
|
|
186
277
|
>>> from mindspore.rewrite import SymbolTree
|
|
187
|
-
>>>
|
|
188
|
-
>>>
|
|
278
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
279
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
280
|
+
>>> net = LeNet5()
|
|
189
281
|
>>> stree = SymbolTree.create(net)
|
|
190
|
-
>>> src_node = stree.get_node("
|
|
191
|
-
>>> dst_node = stree.get_node("
|
|
192
|
-
>>> dst_node.set_arg_by_node(0, src_node)
|
|
282
|
+
>>> src_node = stree.get_node("fc1")
|
|
283
|
+
>>> dst_node = stree.get_node("relu_3")
|
|
284
|
+
>>> dst_node.set_arg_by_node(0, src_node, 0)
|
|
285
|
+
>>> print(dst_node.get_args())
|
|
286
|
+
[fc1]
|
|
193
287
|
"""
|
|
194
288
|
Validator.check_value_type("arg_idx", arg_idx, [int], "Node")
|
|
195
289
|
Validator.check_value_type("src_node", src_node, [Node], "Node")
|
|
@@ -202,6 +296,12 @@ class Node:
|
|
|
202
296
|
belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_node.get_handler(), out_idx)
|
|
203
297
|
|
|
204
298
|
def get_targets(self) -> [ScopedValue]:
|
|
299
|
+
"""
|
|
300
|
+
Gets a list of output values for the current node.
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
A list of outputs of type ``ScopedValue`` .
|
|
304
|
+
"""
|
|
205
305
|
return self._node.get_targets()
|
|
206
306
|
|
|
207
307
|
def get_name(self) -> str:
|
|
@@ -215,43 +315,61 @@ class Node:
|
|
|
215
315
|
|
|
216
316
|
Examples:
|
|
217
317
|
>>> from mindspore.rewrite import SymbolTree
|
|
218
|
-
>>>
|
|
219
|
-
>>>
|
|
318
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
319
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
320
|
+
>>> net = LeNet5()
|
|
220
321
|
>>> stree = SymbolTree.create(net)
|
|
221
322
|
>>> node = stree.get_node("conv1")
|
|
222
323
|
>>> name = node.get_name()
|
|
324
|
+
>>> print(name)
|
|
325
|
+
conv1
|
|
223
326
|
"""
|
|
224
327
|
return self._node.get_name()
|
|
225
328
|
|
|
226
329
|
def get_node_type(self) -> NodeType:
|
|
227
330
|
"""
|
|
228
|
-
Get the node_type of current node.
|
|
331
|
+
Get the node_type of current node. See :class:`mindspore.rewrite.NodeType` for details on node types.
|
|
229
332
|
|
|
230
333
|
Returns:
|
|
231
334
|
A NodeType as node_type of node.
|
|
232
335
|
|
|
233
336
|
Examples:
|
|
234
337
|
>>> from mindspore.rewrite import SymbolTree
|
|
235
|
-
>>>
|
|
236
|
-
>>>
|
|
338
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
339
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
340
|
+
>>> net = LeNet5()
|
|
237
341
|
>>> stree = SymbolTree.create(net)
|
|
238
342
|
>>> node = stree.get_node("conv1")
|
|
239
343
|
>>> node_type = node.get_node_type()
|
|
344
|
+
>>> print(node_type)
|
|
345
|
+
NodeType.CallCell
|
|
240
346
|
"""
|
|
241
347
|
return self._node.get_node_type()
|
|
242
348
|
|
|
243
349
|
def get_instance_type(self) -> type:
|
|
244
350
|
"""
|
|
245
|
-
|
|
351
|
+
Gets the instance type called in the code corresponding to the current node.
|
|
246
352
|
|
|
247
|
-
- When node_type of current node is `CallCell`,
|
|
248
|
-
- When node_type of current node is `CallPrimitive`,
|
|
249
|
-
|
|
250
|
-
- When node_type of current node is `
|
|
251
|
-
|
|
353
|
+
- When `node_type` of current node is `CallCell`, the code for that node calls an instance of type ``Cell`` .
|
|
354
|
+
- When `node_type` of current node is `CallPrimitive`, the code for that node calls an instance of
|
|
355
|
+
type ``Primitive`` .
|
|
356
|
+
- When `node_type` of current node is `Tree`, the code for that node calls an instance of network type.
|
|
357
|
+
- When `node_type` of current node is `Python`, `Input`, `Output` or `CallMethod`, the instance type
|
|
358
|
+
is ``NoneType`` .
|
|
252
359
|
|
|
253
360
|
Returns:
|
|
254
|
-
|
|
361
|
+
The type of instance called in the statement corresponding to the current node.
|
|
362
|
+
|
|
363
|
+
Examples:
|
|
364
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
365
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
366
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
367
|
+
>>> net = LeNet5()
|
|
368
|
+
>>> stree = SymbolTree.create(net)
|
|
369
|
+
>>> node = stree.get_node("conv1")
|
|
370
|
+
>>> instance_type = node.get_instance_type()
|
|
371
|
+
>>> print(instance_type)
|
|
372
|
+
<class 'mindspore.nn.layer.conv.Conv2d'>
|
|
255
373
|
"""
|
|
256
374
|
return self._node.get_instance_type()
|
|
257
375
|
|
|
@@ -259,8 +377,47 @@ class Node:
|
|
|
259
377
|
return self._node.get_instance()
|
|
260
378
|
|
|
261
379
|
def get_args(self) -> [ScopedValue]:
|
|
380
|
+
"""
|
|
381
|
+
Get arguments of current node.
|
|
382
|
+
|
|
383
|
+
Returns:
|
|
384
|
+
A list of arguments of type ``ScopedValue`` .
|
|
385
|
+
|
|
386
|
+
Examples:
|
|
387
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
388
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
389
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
390
|
+
>>> net = LeNet5()
|
|
391
|
+
>>> stree = SymbolTree.create(net)
|
|
392
|
+
>>> node = stree.get_node("conv1")
|
|
393
|
+
>>> print(node.get_args())
|
|
394
|
+
[x]
|
|
395
|
+
"""
|
|
262
396
|
return self._node.get_args()
|
|
263
397
|
|
|
398
|
+
def get_symbol_tree(self) -> 'SymbolTree':
|
|
399
|
+
"""
|
|
400
|
+
Get the symbol tree which current node belongs to.
|
|
401
|
+
|
|
402
|
+
Returns:
|
|
403
|
+
SymbolTree, None if current node does not belong to any SymbolTree.
|
|
404
|
+
|
|
405
|
+
Examples:
|
|
406
|
+
>>> from mindspore.rewrite import SymbolTree
|
|
407
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
408
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
409
|
+
>>> net = LeNet5()
|
|
410
|
+
>>> stree = SymbolTree.create(net)
|
|
411
|
+
>>> node = stree.get_node("conv1")
|
|
412
|
+
>>> print(type(node.get_symbol_tree()))
|
|
413
|
+
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
|
|
414
|
+
"""
|
|
415
|
+
from .symbol_tree import SymbolTree
|
|
416
|
+
stree_impl = self._node.get_belong_symbol_tree()
|
|
417
|
+
if not stree_impl:
|
|
418
|
+
return None
|
|
419
|
+
return SymbolTree(stree_impl)
|
|
420
|
+
|
|
264
421
|
def get_kwargs(self) -> {str: ScopedValue}:
|
|
265
422
|
return self._node.get_kwargs()
|
|
266
423
|
|
|
@@ -23,13 +23,17 @@ class NodeType(Enum):
|
|
|
23
23
|
- Unknown: Not inited NodeType.
|
|
24
24
|
- CallCell: `CallCell` node represents invoking cell-op in forward method.
|
|
25
25
|
- CallPrimitive: `CallPrimitive` node represents invoking primitive-op in forward method.
|
|
26
|
+
- CallFunction: `CallFunction` node represents invoking a function in forward method.
|
|
26
27
|
- CallMethod: `CallMethod` node represents invoking of method in forward method which can not be mapped to
|
|
27
28
|
cell-op or primitive-op in MindSpore.
|
|
28
29
|
- Python: `Python` node holds unsupported-ast-node or unnecessary-to-parse-ast-node.
|
|
29
30
|
- Input: `Input` node represents input of `SymbolTree` corresponding to arguments of forward method.
|
|
30
31
|
- Output: `Output` node represents output of SymbolTree corresponding to return statement of forward method.
|
|
31
32
|
- Tree: `Tree` node represents sub-network invoking in forward method.
|
|
33
|
+
- CellContainer: `CellContainer` node represents invoking method :class:`mindspore.nn.SequentialCell` in
|
|
34
|
+
forward method.
|
|
32
35
|
- MathOps: `MathOps` node represents a mathematical operation, such as adding or comparing in forward method.
|
|
36
|
+
- ControlFlow: `ControlFlow` node represents a control flow statement, such as if statement.
|
|
33
37
|
|
|
34
38
|
"""
|
|
35
39
|
Unknown = 0
|
|
@@ -46,3 +50,4 @@ class NodeType(Enum):
|
|
|
46
50
|
Tree = 9
|
|
47
51
|
CellContainer = 10
|
|
48
52
|
MathOps = 11
|
|
53
|
+
ControlFlow = 12
|
|
@@ -31,10 +31,13 @@ class PatternNode:
|
|
|
31
31
|
"""
|
|
32
32
|
`PatternNode` is defined as a node while defining pattern.
|
|
33
33
|
|
|
34
|
+
.. warning::
|
|
35
|
+
This is a set of experimental APIs that is subject to change or deletion.
|
|
36
|
+
|
|
34
37
|
Args:
|
|
35
38
|
pattern_node_name (str): Name of current node.
|
|
36
|
-
match_type (Type): A type represents what type would be matched of current node. Default: None.
|
|
37
|
-
inputs (list[PatternNode]): Input nodes of current node. Default: None.
|
|
39
|
+
match_type (Type): A type represents what type would be matched of current node. Default: ``Type[None]`` .
|
|
40
|
+
inputs (list[PatternNode]): Input nodes of current node. Default: ``None`` .
|
|
38
41
|
"""
|
|
39
42
|
|
|
40
43
|
def __init__(self, pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None):
|
|
@@ -180,6 +183,9 @@ class PatternNode:
|
|
|
180
183
|
class VarNode(PatternNode):
|
|
181
184
|
"""
|
|
182
185
|
VarNode is a subclass of `PatternNode` whose `match` method is always return True.
|
|
186
|
+
|
|
187
|
+
.. warning::
|
|
188
|
+
This is a set of experimental APIs that is subject to change or deletion.
|
|
183
189
|
"""
|
|
184
190
|
|
|
185
191
|
def __init__(self):
|
|
@@ -193,9 +199,12 @@ class Replacement(abc.ABC):
|
|
|
193
199
|
"""
|
|
194
200
|
Interface of replacement function.
|
|
195
201
|
|
|
202
|
+
.. warning::
|
|
203
|
+
This is a set of experimental APIs that is subject to change or deletion.
|
|
204
|
+
|
|
196
205
|
Examples:
|
|
197
206
|
>>> from mindspore.rewrite import Replacement, Node
|
|
198
|
-
>>>
|
|
207
|
+
>>> import mindspore.nn as nn
|
|
199
208
|
>>> class BnReplacement(Replacement):
|
|
200
209
|
... def build(self, pattern, is_chain_pattern: bool, matched):
|
|
201
210
|
... bn_node: Node = matched.get(pattern.name())
|
|
@@ -232,10 +241,13 @@ class PatternEngine:
|
|
|
232
241
|
"""
|
|
233
242
|
`PatternEngine` is defined how to transform a `SymbolTree` by `PattenNode`.
|
|
234
243
|
|
|
244
|
+
.. warning::
|
|
245
|
+
This is a set of experimental APIs that is subject to change or deletion.
|
|
246
|
+
|
|
235
247
|
Args:
|
|
236
248
|
pattern (Union[PatternNode, List]): An instance of `PatternNode` or a cell-type-list to construct `PatternNode`
|
|
237
249
|
as root of a pattern.
|
|
238
|
-
replacement (callable): A callable define how to generate new_node.
|
|
250
|
+
replacement (callable): A callable define how to generate new_node. Default: ``None`` .
|
|
239
251
|
"""
|
|
240
252
|
|
|
241
253
|
def __init__(self, pattern: Union[PatternNode, List], replacement: Replacement = None):
|
|
@@ -285,27 +297,14 @@ class PatternEngine:
|
|
|
285
297
|
is the matched node.
|
|
286
298
|
new_nodes (list[Node]): A list of instance of Node as replacement.
|
|
287
299
|
"""
|
|
288
|
-
to_erase_list = matched_dict.values()
|
|
289
|
-
# keep all old nodes' inputs
|
|
290
|
-
inputs_dict = {}
|
|
291
|
-
for node in to_erase_list:
|
|
292
|
-
inputs_dict[node.get_name()] = (node.get_inputs())
|
|
293
300
|
# call replace of SymbolTree
|
|
294
301
|
new_root = stree.replace(old_root, new_nodes)
|
|
295
302
|
# replace only support one-to-one replace or one-to-multi replace, we need to erase nodes except
|
|
296
303
|
# cur_node manually
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
if cur_node.get_users():
|
|
302
|
-
# if cur_node is depended on by other node, skip now.
|
|
303
|
-
# cur_node will be push into queue and be erased later
|
|
304
|
-
continue
|
|
305
|
-
if stree.get_node(cur_node.get_name()) is not None:
|
|
306
|
-
# cur_node is not erased before
|
|
307
|
-
stree.erase_node(cur_node)
|
|
308
|
-
queue.extend(inputs_dict.get(cur_node.get_name()))
|
|
304
|
+
to_erase_list = matched_dict.values()
|
|
305
|
+
for node in reversed(to_erase_list):
|
|
306
|
+
if node != old_root:
|
|
307
|
+
stree.erase(node)
|
|
309
308
|
return new_root
|
|
310
309
|
|
|
311
310
|
@staticmethod
|
|
@@ -316,7 +315,7 @@ class PatternEngine:
|
|
|
316
315
|
for n in reversed(to_erase_list):
|
|
317
316
|
if n.get_handler() is node:
|
|
318
317
|
continue
|
|
319
|
-
stree.
|
|
318
|
+
stree.erase(n)
|
|
320
319
|
|
|
321
320
|
def apply(self, stree: SymbolTree) -> bool:
|
|
322
321
|
"""
|
|
@@ -354,7 +353,7 @@ class PatternEngine:
|
|
|
354
353
|
# 2. Visit b, b does not match pattern, add a to queue.
|
|
355
354
|
# 3. Visit d, d does not match pattern, add c to queue.
|
|
356
355
|
# 4. Visit a, a matches pattern and erased from SymbolTree, add xx to queue.
|
|
357
|
-
# 5. Visit c,
|
|
356
|
+
# 5. Visit c, c does not match pattern, add a to queue.
|
|
358
357
|
# At step 5, a is visited at second time but a is erased from SymbolTree at step 4.
|
|
359
358
|
visited: [Node] = []
|
|
360
359
|
while queue:
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Rewrite module api: ValueType and ScopedValue."""
|
|
16
16
|
from enum import Enum
|
|
17
|
-
from typing import Optional, Union
|
|
17
|
+
from typing import Optional, Union, List, Tuple
|
|
18
18
|
from mindspore import _checkparam as Validator
|
|
19
19
|
|
|
20
20
|
|
|
@@ -28,9 +28,7 @@ class ValueType(Enum):
|
|
|
28
28
|
"""
|
|
29
29
|
|
|
30
30
|
# base type
|
|
31
|
-
|
|
32
|
-
IntValue = 1
|
|
33
|
-
FloatValue = 2
|
|
31
|
+
ConstantValue = 0
|
|
34
32
|
# container type
|
|
35
33
|
TupleValue = 20
|
|
36
34
|
ListValue = 21
|
|
@@ -50,8 +48,9 @@ class ScopedValue:
|
|
|
50
48
|
Args:
|
|
51
49
|
arg_type (ValueType): A `ValueType` represents type of current value.
|
|
52
50
|
scope (str): A string represents scope of current value. Take "self.var1" as an example, `scope` of this
|
|
53
|
-
var1 is "self".
|
|
51
|
+
var1 is "self". Default: ``""`` .
|
|
54
52
|
value: A handler represents value of current value. The type of value is corresponding to `arg_type`.
|
|
53
|
+
Default: ``None`` .
|
|
55
54
|
"""
|
|
56
55
|
|
|
57
56
|
def __init__(self, arg_type: ValueType, scope: str = "", value=None):
|
|
@@ -77,13 +76,11 @@ class ScopedValue:
|
|
|
77
76
|
Examples:
|
|
78
77
|
>>> from mindspore.rewrite import ScopedValue
|
|
79
78
|
>>> variable = ScopedValue.create_variable_value(2)
|
|
79
|
+
>>> print(variable)
|
|
80
|
+
2
|
|
80
81
|
"""
|
|
81
|
-
if isinstance(value, int):
|
|
82
|
-
return cls(ValueType.
|
|
83
|
-
if isinstance(value, float):
|
|
84
|
-
return cls(ValueType.FloatValue, "", value)
|
|
85
|
-
if isinstance(value, str):
|
|
86
|
-
return cls(ValueType.StringValue, "", value)
|
|
82
|
+
if isinstance(value, (type(None), int, float, str, bool)):
|
|
83
|
+
return cls(ValueType.ConstantValue, "", value)
|
|
87
84
|
if isinstance(value, tuple):
|
|
88
85
|
return cls(ValueType.TupleValue, "",
|
|
89
86
|
tuple(cls.create_variable_value(single_value) for single_value in value))
|
|
@@ -104,7 +101,7 @@ class ScopedValue:
|
|
|
104
101
|
|
|
105
102
|
Args:
|
|
106
103
|
name: (str): A string represents the identifier of another variable.
|
|
107
|
-
scope: (str): A string represents the scope of another variable.
|
|
104
|
+
scope: (str): A string represents the scope of another variable. Default: ``""`` .
|
|
108
105
|
|
|
109
106
|
Returns:
|
|
110
107
|
An instance of `ScopedValue`.
|
|
@@ -116,19 +113,23 @@ class ScopedValue:
|
|
|
116
113
|
Examples:
|
|
117
114
|
>>> from mindspore.rewrite import ScopedValue
|
|
118
115
|
>>> variable = ScopedValue.create_naming_value("conv", "self")
|
|
116
|
+
>>> print(variable)
|
|
117
|
+
self.conv
|
|
119
118
|
"""
|
|
120
119
|
Validator.check_value_type("name", name, [str], "ScopedValue")
|
|
121
120
|
Validator.check_value_type("scope", scope, [str], "ScopedValue")
|
|
122
121
|
return cls(ValueType.NamingValue, scope, name)
|
|
123
122
|
|
|
124
123
|
@staticmethod
|
|
125
|
-
def create_name_values(names: Union[
|
|
124
|
+
def create_name_values(names: Union[List[str], Tuple[str]],
|
|
125
|
+
scopes: Union[List[str], Tuple[str]] = None) -> List['ScopedValue']:
|
|
126
126
|
"""
|
|
127
127
|
Create a list of naming `ScopedValue`.
|
|
128
128
|
|
|
129
129
|
Args:
|
|
130
|
-
names (
|
|
131
|
-
scopes (
|
|
130
|
+
names (List[str] or Tuple[str]): List or tuple of `str` represents names of referenced variables.
|
|
131
|
+
scopes (List[str] or Tuple[str]): List or tuple of `str` represents scopes of referenced variables.
|
|
132
|
+
Default: ``None`` .
|
|
132
133
|
|
|
133
134
|
Returns:
|
|
134
135
|
An list of instance of `ScopedValue`.
|
|
@@ -140,7 +141,9 @@ class ScopedValue:
|
|
|
140
141
|
|
|
141
142
|
Examples:
|
|
142
143
|
>>> from mindspore.rewrite import ScopedValue
|
|
143
|
-
>>> variables = ScopedValue.create_name_values(["z", "z_1"],
|
|
144
|
+
>>> variables = ScopedValue.create_name_values(names=["z", "z_1"], scopes=["self", "self"])
|
|
145
|
+
>>> print(variables)
|
|
146
|
+
[self.z, self.z_1]
|
|
144
147
|
"""
|
|
145
148
|
Validator.check_element_type_of_iterable("names", names, [str], "ScopedValue")
|
|
146
149
|
if scopes is not None:
|
|
@@ -157,7 +160,7 @@ class ScopedValue:
|
|
|
157
160
|
return result
|
|
158
161
|
|
|
159
162
|
def __str__(self):
|
|
160
|
-
if self.type
|
|
163
|
+
if self.type == ValueType.ConstantValue:
|
|
161
164
|
return str(self.value)
|
|
162
165
|
if self.type == ValueType.NamingValue:
|
|
163
166
|
return f"{self.scope}.{self.value}" if self.scope else str(self.value)
|