mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -325,9 +325,10 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
|
|
|
325
325
|
# If rank is larger than 1, we need to reduce result when reduction != 'none'
|
|
326
326
|
if max_rank > 1:
|
|
327
327
|
reduce_indexes = tuple(range(1, max_rank))
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
328
|
+
logits_dim_ok = logits_dim == label_dim and logits_dim == weight_dim and logits_dim == pos_weight_dim
|
|
329
|
+
shape = F.shape(logits)
|
|
330
|
+
shape_ok = shape == F.shape(label) and shape == F.shape(weight) and shape == F.shape(pos_weight)
|
|
331
|
+
if logits_dim_ok and shape_ok:
|
|
331
332
|
if prim_reduction == 'none':
|
|
332
333
|
output = prim(logits, label, weight, pos_weight)
|
|
333
334
|
elif prim_reduction in ('mean', 'sum'):
|
|
@@ -798,7 +799,8 @@ def get_instance_norm_rule(prim, axis_size):
|
|
|
798
799
|
output_x, updated_moving_mean, updated_moving_variance = prim(input_x, gamma, beta, mean, variance, u_monad)
|
|
799
800
|
return (output_x, None), (updated_moving_mean, None), (updated_moving_variance, None)
|
|
800
801
|
|
|
801
|
-
|
|
802
|
+
precondition = gamma_dim != 0 or beta_dim != gamma_dim or mean_dim != gamma_dim or variance_dim != gamma_dim
|
|
803
|
+
if precondition:
|
|
802
804
|
# pylint: disable=too-many-format-args
|
|
803
805
|
raise ValueError(
|
|
804
806
|
"For `{}`, the source axis of `var` must be equal to `accum` and `accum_update`, and not equal to 0, "
|
|
@@ -1309,6 +1311,61 @@ def get_apply_adam_with_amsgrad_rule(prim, axis_size):
|
|
|
1309
1311
|
return vmap_rule
|
|
1310
1312
|
|
|
1311
1313
|
|
|
1314
|
+
@vmap_rules_getters.register(P.ApplyAdamWithAmsgradV2)
|
|
1315
|
+
def get_apply_adam_with_amsgrad_v2_rule(prim, axis_size):
|
|
1316
|
+
"""VmapRule for `ApplyAdamWithAmsgradV2` operation"""
|
|
1317
|
+
if hasattr(prim, "batch_rank"):
|
|
1318
|
+
batch_rank = prim.batch_rank + 1
|
|
1319
|
+
else:
|
|
1320
|
+
batch_rank = 1
|
|
1321
|
+
prim_name = prim.name
|
|
1322
|
+
batch_prim = _vmap_clone_prim(prim)
|
|
1323
|
+
batch_prim.add_prim_attr("batch_rank", batch_rank)
|
|
1324
|
+
|
|
1325
|
+
def vmap_rule(var_bdim, m_bdim, v_bdim, vhat_bdim, beta1_power_bdim, beta2_power_bdim, lr_bdim, beta1_bdim,
|
|
1326
|
+
beta2_bdim, epsilon_bdim, grad_bdim, u_monad):
|
|
1327
|
+
var, var_dim = var_bdim
|
|
1328
|
+
m, m_dim = m_bdim
|
|
1329
|
+
v, v_dim = v_bdim
|
|
1330
|
+
vhat, vhat_dim = vhat_bdim
|
|
1331
|
+
beta1_power, beta1_power_dim = beta1_power_bdim
|
|
1332
|
+
beta2_power, beta2_power_dim = beta2_power_bdim
|
|
1333
|
+
lr, lr_dim = lr_bdim
|
|
1334
|
+
beta1, beta1_dim = beta1_bdim
|
|
1335
|
+
beta2, beta2_dim = beta2_bdim
|
|
1336
|
+
epsilon, epsilon_dim = epsilon_bdim
|
|
1337
|
+
grad, grad_dim = grad_bdim
|
|
1338
|
+
|
|
1339
|
+
if var_dim is None:
|
|
1340
|
+
if any(dim is not None for dim in [m_dim, v_dim, vhat_dim, beta1_power_dim,
|
|
1341
|
+
beta2_power_dim, lr_dim, beta1_dim, beta2_dim, grad_dim]):
|
|
1342
|
+
raise ValueError("The source axis of `var` is None, "
|
|
1343
|
+
"but the source axis of `m/v/vhat/beta1_power/beta2_power/lr/beta1/beta2/grad` is not "
|
|
1344
|
+
"None. The execution of operator `{}` cannot be guaranteed.".format(prim_name))
|
|
1345
|
+
out_var, out_m, out_v, out_vhat = prim(var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2, epsilon,
|
|
1346
|
+
grad, u_monad)
|
|
1347
|
+
return (out_var, None), (out_m, None), (out_v, None), (out_vhat, None)
|
|
1348
|
+
|
|
1349
|
+
if any(dim != 0 for dim in [var_dim, m_dim, v_dim, vhat_dim]):
|
|
1350
|
+
raise ValueError("For `{}`, the source axis of `var/m/v/vhat` must be 0, "
|
|
1351
|
+
"but get `var`: {}, `m`: {}, `v`: {}, `vhat`: {}".format(prim_name, var_dim,
|
|
1352
|
+
m_dim, v_dim, vhat_dim))
|
|
1353
|
+
|
|
1354
|
+
beta1_power = _bdim_at_front(beta1_power, beta1_power_dim, axis_size)
|
|
1355
|
+
beta2_power = _bdim_at_front(beta2_power, beta2_power_dim, axis_size)
|
|
1356
|
+
lr = _bdim_at_front(lr, lr_dim, axis_size)
|
|
1357
|
+
beta1 = _bdim_at_front(beta1, beta1_dim, axis_size)
|
|
1358
|
+
beta2 = _bdim_at_front(beta2, beta2_dim, axis_size)
|
|
1359
|
+
epsilon = _bdim_at_front(epsilon, epsilon_dim, axis_size)
|
|
1360
|
+
grad = _bdim_at_front(grad, grad_dim, axis_size)
|
|
1361
|
+
|
|
1362
|
+
out_var, out_m, out_v, out_vhat = batch_prim(var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2,
|
|
1363
|
+
epsilon, grad, u_monad)
|
|
1364
|
+
return (out_var, 0), (out_m, 0), (out_v, 0), (out_vhat, 0)
|
|
1365
|
+
|
|
1366
|
+
return vmap_rule
|
|
1367
|
+
|
|
1368
|
+
|
|
1312
1369
|
@vmap_rules_getters.register(P.Adam)
|
|
1313
1370
|
def get_adam_rule(prim, axis_size):
|
|
1314
1371
|
"""VmapRule for `Adam` operation"""
|
|
@@ -1624,7 +1681,8 @@ def get_rmsprop_vmap_rule(prim, axis_size):
|
|
|
1624
1681
|
res = prim(var, mean_square, moment, lr, grad, decay, momentum, epsilon,
|
|
1625
1682
|
u_monad) # low dimensional operator;
|
|
1626
1683
|
return (res, None)
|
|
1627
|
-
|
|
1684
|
+
precondition = var_dim != 0 or var_dim != mean_square_dim or var_dim != moment_dim or var_dim != grad_dim
|
|
1685
|
+
if precondition:
|
|
1628
1686
|
raise ValueError(
|
|
1629
1687
|
f"For '{prim_name}', the source axis of 'var' must be equal to 'mean_square_dim' "
|
|
1630
1688
|
f"and 'moment_dim' and 'grad_dim' and not equal to 0, "
|
|
@@ -1680,8 +1738,8 @@ def get_apply_centered_rmsprop_vmap_rule(prim, axis_size):
|
|
|
1680
1738
|
var = prim(var, mean_grad, mean_square,
|
|
1681
1739
|
mom, grad, lr, rho, momentum, eps, u_monad)
|
|
1682
1740
|
return (var, None)
|
|
1683
|
-
|
|
1684
|
-
if
|
|
1741
|
+
precondition = var_dim != 0 or var_dim != mean_grad_dim or var_dim != mean_square_dim or var_dim != mom_dim
|
|
1742
|
+
if precondition:
|
|
1685
1743
|
raise ValueError(
|
|
1686
1744
|
f"For '{prim_name}', the source axis of 'var' must be equal to 'mean_grad_dim' "
|
|
1687
1745
|
f"and 'mean_square_dim' and 'mom_dim' and not equal to 0, "
|
|
@@ -1748,6 +1806,7 @@ def get_max_pool_vmap_rule(prim, axis_size):
|
|
|
1748
1806
|
@vmap_rules_getters.register(P.LayerNorm)
|
|
1749
1807
|
def get_layernorm_vmap_rule(prim, axis_size):
|
|
1750
1808
|
"""VmapRule for `LayerNorm` operation."""
|
|
1809
|
+
|
|
1751
1810
|
@constexpr
|
|
1752
1811
|
def process_attr_axis(prim_attr_axis):
|
|
1753
1812
|
if prim_attr_axis < 0:
|
|
@@ -1794,6 +1853,7 @@ def get_layernorm_vmap_rule(prim, axis_size):
|
|
|
1794
1853
|
output = F.add(F.mul(output_tmp, g), b)
|
|
1795
1854
|
|
|
1796
1855
|
return (output, 0), (mean, 0), (var, 0)
|
|
1856
|
+
|
|
1797
1857
|
return vmap_rule
|
|
1798
1858
|
|
|
1799
1859
|
|
|
@@ -1828,6 +1888,7 @@ def get_grid_sampler_vmap_rule(prim, axis_size):
|
|
|
1828
1888
|
return_shape = input_x_shape[:non_batch_dim_index] + out_shape[non_batch_dim_index:]
|
|
1829
1889
|
out = F.reshape(out, return_shape)
|
|
1830
1890
|
return out, 0
|
|
1891
|
+
|
|
1831
1892
|
return vmap_rule
|
|
1832
1893
|
|
|
1833
1894
|
|
|
@@ -1837,21 +1898,31 @@ def get_upsample_nearest_3d_vmap_rule(prim, axis_size):
|
|
|
1837
1898
|
"""VmapRule for `UpsampleNearest3D` and `UpsampleTrilinear3D`."""
|
|
1838
1899
|
cdhw_reverse_index = -4
|
|
1839
1900
|
|
|
1840
|
-
def vmap_rule(x_bdim):
|
|
1841
|
-
is_all_none, result = vmap_general_preprocess(prim, x_bdim
|
|
1901
|
+
def vmap_rule(x_bdim, size_bdim, scales_bdim):
|
|
1902
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim, size_bdim,
|
|
1903
|
+
scales_bdim)
|
|
1842
1904
|
if is_all_none:
|
|
1843
1905
|
return result
|
|
1844
1906
|
|
|
1845
1907
|
x, x_dim = x_bdim
|
|
1846
1908
|
x = _bdim_at_front(x, x_dim, axis_size)
|
|
1909
|
+
size, size_dim = size_bdim
|
|
1910
|
+
scales, scales_dim = scales_bdim
|
|
1911
|
+
if size_dim is not None or scales_dim is not None:
|
|
1912
|
+
_raise_value_error(
|
|
1913
|
+
"The source axis of `output_size` and `scales` must be None, but got {0} and {1}."
|
|
1914
|
+
.format(size_dim, scales_dim))
|
|
1915
|
+
|
|
1847
1916
|
x_shape = F.shape(x)
|
|
1848
1917
|
input_shape = (-1,) + x_shape[cdhw_reverse_index:]
|
|
1849
1918
|
x = F.reshape(x, input_shape)
|
|
1850
|
-
out = prim(x)
|
|
1919
|
+
out = prim(x, size, scales)
|
|
1851
1920
|
out_shape = F.shape(out)
|
|
1852
|
-
return_shape = x_shape[:cdhw_reverse_index] + out_shape[
|
|
1921
|
+
return_shape = x_shape[:cdhw_reverse_index] + out_shape[
|
|
1922
|
+
cdhw_reverse_index:]
|
|
1853
1923
|
out = F.reshape(out, return_shape)
|
|
1854
1924
|
return out, 0
|
|
1925
|
+
|
|
1855
1926
|
return vmap_rule
|
|
1856
1927
|
|
|
1857
1928
|
|
|
@@ -1889,6 +1960,7 @@ def get_sparse_apply_adagrad_vmap_rule(prim, axis_size):
|
|
|
1889
1960
|
|
|
1890
1961
|
var, accum = batch_prim(var, accum, grad, indices, u_monad)
|
|
1891
1962
|
return (var, 0), (accum, 0)
|
|
1963
|
+
|
|
1892
1964
|
return vmap_rule
|
|
1893
1965
|
|
|
1894
1966
|
|
|
@@ -1927,6 +1999,58 @@ def get_sparse_apply_ftrl_vmap_rule(prim, axis_size):
|
|
|
1927
1999
|
|
|
1928
2000
|
var, accum, linear = batch_prim(var, accum, linear, grad, indices, u_monad)
|
|
1929
2001
|
return (var, 0), (accum, 0), (linear, 0)
|
|
2002
|
+
|
|
2003
|
+
return vmap_rule
|
|
2004
|
+
|
|
2005
|
+
|
|
2006
|
+
@vmap_rules_getters.register(P.Dense)
|
|
2007
|
+
def get_dense_vmap_rule(prim, axis_size):
|
|
2008
|
+
"""VmapRule for `Dense` operation."""
|
|
2009
|
+
if isinstance(prim, str):
|
|
2010
|
+
prim = Primitive(prim)
|
|
2011
|
+
|
|
2012
|
+
batch_matmul = P.BatchMatMul(transpose_b=True)
|
|
2013
|
+
|
|
2014
|
+
@_primexpr
|
|
2015
|
+
def get_start_mid_end(x_shape):
|
|
2016
|
+
start = x_shape[0]
|
|
2017
|
+
mid = 1
|
|
2018
|
+
for shp in x_shape[1:-1]:
|
|
2019
|
+
mid *= shp
|
|
2020
|
+
end = x_shape[-1]
|
|
2021
|
+
return start, mid, end
|
|
2022
|
+
|
|
2023
|
+
def vmap_rule(x_bdim, w_bdim, b_bdim):
|
|
2024
|
+
is_all_none, result = vmap_general_preprocess(prim, x_bdim, w_bdim, b_bdim)
|
|
2025
|
+
if is_all_none:
|
|
2026
|
+
return result
|
|
2027
|
+
|
|
2028
|
+
x, x_dim = x_bdim
|
|
2029
|
+
w, w_dim = w_bdim
|
|
2030
|
+
b, b_dim = b_bdim
|
|
2031
|
+
x = _bdim_at_front(x, x_dim, axis_size)
|
|
2032
|
+
w = _bdim_at_front(w, w_dim, axis_size)
|
|
2033
|
+
if b is not None:
|
|
2034
|
+
b = _bdim_at_front(b, b_dim, axis_size)
|
|
2035
|
+
|
|
2036
|
+
x_shape = x.shape
|
|
2037
|
+
start, mid, end = get_start_mid_end(x_shape)
|
|
2038
|
+
|
|
2039
|
+
x = x.reshape(start, mid, end)
|
|
2040
|
+
|
|
2041
|
+
out = batch_matmul(x, w)
|
|
2042
|
+
out_shape = tuple(x_shape[:-1]) + (out.shape[-1],)
|
|
2043
|
+
out = out.reshape(out_shape)
|
|
2044
|
+
|
|
2045
|
+
if b is not None:
|
|
2046
|
+
b_shape = b.shape
|
|
2047
|
+
b_shape = (start,) + (1,) * (len(out_shape) - 2) + (b_shape[-1],)
|
|
2048
|
+
b = b.reshape(b_shape)
|
|
2049
|
+
|
|
2050
|
+
out = out + b
|
|
2051
|
+
|
|
2052
|
+
return out, 0
|
|
2053
|
+
|
|
1930
2054
|
return vmap_rule
|
|
1931
2055
|
|
|
1932
2056
|
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
|
+
#
|
|
3
|
+
# Copyright 2023-2024 Huawei Technologies Co., Ltd
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
# ============================================================================
|
|
17
|
+
"""Operator argument data type cast function."""
|
|
18
|
+
from enum import Enum
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TypeCastKind(Enum):
|
|
22
|
+
INT_TO_TUPLE = 1
|
|
23
|
+
INT_OR_TUPLE_TO_LIST = 2
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def type_it(src_data, cast_type):
|
|
27
|
+
"""
|
|
28
|
+
cast operator argument data type.
|
|
29
|
+
"""
|
|
30
|
+
if cast_type == TypeCastKind.INT_TO_TUPLE:
|
|
31
|
+
if isinstance(src_data, tuple):
|
|
32
|
+
return src_data
|
|
33
|
+
|
|
34
|
+
if isinstance(src_data, int):
|
|
35
|
+
return (src_data,)
|
|
36
|
+
|
|
37
|
+
raise TypeError(f'{src_data} is the wrong data type.')
|
|
38
|
+
|
|
39
|
+
if cast_type == TypeCastKind.INT_OR_TUPLE_TO_LIST:
|
|
40
|
+
if isinstance(src_data, list):
|
|
41
|
+
return src_data
|
|
42
|
+
|
|
43
|
+
if isinstance(src_data, int):
|
|
44
|
+
return [
|
|
45
|
+
src_data,
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
if isinstance(src_data, tuple):
|
|
49
|
+
dst_list = [item for item in src_data]
|
|
50
|
+
return dst_list
|
|
51
|
+
|
|
52
|
+
raise TypeError(f'{src_data} is the wrong data type.')
|
|
53
|
+
|
|
54
|
+
raise TypeError("Unsupported type cast")
|
|
@@ -28,7 +28,8 @@ from mindspore.ops.composite.multitype_ops.add_impl import hyper_add
|
|
|
28
28
|
from mindspore.ops.composite.multitype_ops.ones_like_impl import ones_like
|
|
29
29
|
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
30
30
|
from mindspore.ops.function.random_func import normal, laplace, uniform, gamma, poisson, multinomial
|
|
31
|
-
from mindspore.ops.composite.math_ops import
|
|
31
|
+
from mindspore.ops.composite.math_ops import matmul, cummin, mm
|
|
32
|
+
from mindspore.ops.function.math_func import count_nonzero, tensor_dot, vecdot, dot, batch_dot
|
|
32
33
|
from mindspore.ops.function.array_func import repeat_interleave, repeat_elements, sequence_mask
|
|
33
34
|
from mindspore.ops.function.vmap_func import _VmapGeneralPreprocess, _VmapGeneralRule
|
|
34
35
|
from mindspore.ops.function.clip_func import clip_by_value
|
|
@@ -52,15 +53,16 @@ __all__ = [
|
|
|
52
53
|
'gamma',
|
|
53
54
|
'poisson',
|
|
54
55
|
'multinomial',
|
|
55
|
-
'count_nonzero',
|
|
56
56
|
'cummin',
|
|
57
|
-
'tensor_dot',
|
|
58
|
-
'dot',
|
|
59
|
-
'batch_dot',
|
|
60
57
|
'repeat_elements',
|
|
61
58
|
'repeat_interleave',
|
|
62
59
|
'sequence_mask',
|
|
63
60
|
'matmul',
|
|
64
61
|
'mm',
|
|
62
|
+
'count_nonzero',
|
|
63
|
+
'tensor_dot',
|
|
64
|
+
'vecdot',
|
|
65
|
+
'dot',
|
|
66
|
+
'batch_dot',
|
|
65
67
|
'_Grad',
|
|
66
68
|
'_Vmap']
|
mindspore/ops/composite/base.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
2
|
#
|
|
3
|
-
# Copyright 2020-
|
|
3
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
4
4
|
#
|
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
6
|
# you may not use this file except in compliance with the License.
|
|
@@ -20,6 +20,7 @@ from __future__ import absolute_import
|
|
|
20
20
|
from functools import partial
|
|
21
21
|
|
|
22
22
|
from types import FunctionType, MethodType
|
|
23
|
+
import numpy as np
|
|
23
24
|
import mindspore as ms
|
|
24
25
|
from mindspore import context
|
|
25
26
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
@@ -28,14 +29,17 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
|
|
|
28
29
|
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
|
29
30
|
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
|
|
30
31
|
ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
|
|
31
|
-
ZerosLike_
|
|
32
|
+
ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
|
|
33
|
+
HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_
|
|
32
34
|
from mindspore.common import dtype as mstype
|
|
33
35
|
from mindspore.common.api import jit, _pynative_executor, _wrap_func
|
|
34
36
|
from mindspore.common.api import _add_flags, _core
|
|
35
37
|
from mindspore.ops.primitive import Primitive
|
|
36
38
|
from mindspore.ops import signature as sig
|
|
37
39
|
|
|
38
|
-
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_,
|
|
40
|
+
__all__ = [TupleAdd_, ListAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_,
|
|
41
|
+
ListSliceSetItem_, ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_,
|
|
42
|
+
HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_]
|
|
39
43
|
|
|
40
44
|
|
|
41
45
|
def add_flags(fn=None, **flags):
|
|
@@ -46,8 +50,8 @@ def add_flags(fn=None, **flags):
|
|
|
46
50
|
Only supports bool value.
|
|
47
51
|
|
|
48
52
|
Args:
|
|
49
|
-
fn (Function): Function or cell to add flag. Default: None.
|
|
50
|
-
flags (dict): Flags use kwargs. Default: None.
|
|
53
|
+
fn (Function): Function or cell to add flag. Default: ``None`` .
|
|
54
|
+
flags (dict): Flags use kwargs. Default: ``None`` .
|
|
51
55
|
|
|
52
56
|
Returns:
|
|
53
57
|
Function, the function with added flags.
|
|
@@ -70,9 +74,9 @@ def core(fn=None, **flags):
|
|
|
70
74
|
set flag to a graph.
|
|
71
75
|
|
|
72
76
|
Args:
|
|
73
|
-
fn (Function, optional): Function to add flag. Default: None.
|
|
77
|
+
fn (Function, optional): Function to add flag. Default: ``None`` .
|
|
74
78
|
flags (dict, optional): The following flags can be set core, which indicates that this is a core function or
|
|
75
|
-
other flag. Default: None.
|
|
79
|
+
other flag. Default: ``None`` .
|
|
76
80
|
|
|
77
81
|
Supported Platforms:
|
|
78
82
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -187,16 +191,16 @@ class GradOperation(GradOperation_):
|
|
|
187
191
|
- Return an empty tuple for no result.
|
|
188
192
|
|
|
189
193
|
Args:
|
|
190
|
-
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
|
|
191
|
-
get_by_list (bool): If True, get all the gradients with respect to Parameter free variables.
|
|
192
|
-
If get_all and get_by_list are both False, get the gradient with respect to first input.
|
|
193
|
-
If get_all and get_by_list are both True, get the gradients with respect to inputs and
|
|
194
|
+
get_all (bool): If ``True`` , get all the gradients with respect to inputs. Default: ``False`` .
|
|
195
|
+
get_by_list (bool): If ``True`` , get all the gradients with respect to Parameter free variables.
|
|
196
|
+
If get_all and get_by_list are both ``False`` , get the gradient with respect to first input.
|
|
197
|
+
If get_all and get_by_list are both ``True`` , get the gradients with respect to inputs and
|
|
194
198
|
Parameter free variables at the same time in the form of ("gradients with respect to inputs",
|
|
195
|
-
"gradients with respect to parameter free variables"). Default: False.
|
|
199
|
+
"gradients with respect to parameter free variables"). Default: ``False`` .
|
|
196
200
|
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
|
|
197
|
-
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
|
|
198
|
-
Default: False.
|
|
199
|
-
If the sensor_param is True, a sensitivity (gradient with respect to output) needs to be transferred
|
|
201
|
+
If sens_param is ``False`` , a 'ones_like(outputs)' sensitivity will be attached automatically.
|
|
202
|
+
Default: ``False`` .
|
|
203
|
+
If the sensor_param is ``True`` , a sensitivity (gradient with respect to output) needs to be transferred
|
|
200
204
|
through the location parameter or key-value pair parameter. If the value is transferred through
|
|
201
205
|
the key-value pair parameter, the key must be sens.
|
|
202
206
|
|
|
@@ -210,6 +214,10 @@ class GradOperation(GradOperation_):
|
|
|
210
214
|
``Ascend`` ``GPU`` ``CPU``
|
|
211
215
|
|
|
212
216
|
Examples:
|
|
217
|
+
>>> import mindspore
|
|
218
|
+
>>> import numpy as np
|
|
219
|
+
>>> from mindspore import dtype as mstype
|
|
220
|
+
>>> from mindspore import Tensor, ops, nn, Parameter
|
|
213
221
|
>>> class Net(nn.Cell):
|
|
214
222
|
... def __init__(self):
|
|
215
223
|
... super(Net, self).__init__()
|
|
@@ -329,7 +337,7 @@ class GradOperation(GradOperation_):
|
|
|
329
337
|
self.get_all = get_all
|
|
330
338
|
self.get_by_list = get_by_list
|
|
331
339
|
self.sens_param = sens_param
|
|
332
|
-
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False)
|
|
340
|
+
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False, False)
|
|
333
341
|
self.grad_fn = None
|
|
334
342
|
self.fn = None
|
|
335
343
|
self.weights_id = None
|
|
@@ -360,6 +368,9 @@ class GradOperation(GradOperation_):
|
|
|
360
368
|
def after_grad(*args, **kwargs):
|
|
361
369
|
return grad_(fn)(*args, **kwargs)
|
|
362
370
|
elif self.pynative_:
|
|
371
|
+
if not _pynative_executor.enable_grad():
|
|
372
|
+
raise RuntimeError("In no_grad context, you can not calculate gradient")
|
|
373
|
+
|
|
363
374
|
@_wrap_func
|
|
364
375
|
def after_grad(*args, **kwargs):
|
|
365
376
|
self._pynative_forward_run(fn, grad_, weights, args, kwargs)
|
|
@@ -369,6 +380,8 @@ class GradOperation(GradOperation_):
|
|
|
369
380
|
return out
|
|
370
381
|
else:
|
|
371
382
|
grad_.pynative_ = True
|
|
383
|
+
if not _pynative_executor.enable_grad():
|
|
384
|
+
raise RuntimeError("In no_grad context, you can not calculate gradient")
|
|
372
385
|
# after_grad of this branch can't use @jit, just directly call grad_
|
|
373
386
|
if self.get_by_list:
|
|
374
387
|
def after_grad(*args, **kwargs):
|
|
@@ -501,8 +514,8 @@ class _Grad(GradOperation_):
|
|
|
501
514
|
A higher-order function which is used to generate the gradient function by position for the input function.
|
|
502
515
|
"""
|
|
503
516
|
|
|
504
|
-
def __init__(self, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False,
|
|
505
|
-
return_ids=False):
|
|
517
|
+
def __init__(self, get_all=False, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False,
|
|
518
|
+
get_value=False, return_ids=False, merge_forward=False):
|
|
506
519
|
"""Initialize _Grad."""
|
|
507
520
|
if not isinstance(get_by_position, bool):
|
|
508
521
|
raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
|
|
@@ -522,14 +535,16 @@ class _Grad(GradOperation_):
|
|
|
522
535
|
if not isinstance(return_ids, bool):
|
|
523
536
|
raise TypeError(f"For '_Grad', the 'return_ids' should be bool, "
|
|
524
537
|
f"but got {type(return_ids).__name__}")
|
|
538
|
+
self.get_all = get_all
|
|
525
539
|
self.get_by_position = get_by_position
|
|
526
540
|
self.get_by_list = get_by_list
|
|
527
541
|
self.sens_param = sens_param
|
|
528
542
|
self.has_aux = has_aux
|
|
529
543
|
self.get_value = get_value
|
|
530
544
|
self.return_ids = return_ids
|
|
531
|
-
|
|
532
|
-
|
|
545
|
+
self.merge_forward = merge_forward
|
|
546
|
+
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, get_by_position, has_aux, get_value,
|
|
547
|
+
return_ids, merge_forward)
|
|
533
548
|
self.grad_fn = None
|
|
534
549
|
self.fn = None
|
|
535
550
|
self.pynative_ = False
|
|
@@ -552,8 +567,8 @@ class _Grad(GradOperation_):
|
|
|
552
567
|
res += (stop_gradient(item),)
|
|
553
568
|
return res
|
|
554
569
|
|
|
555
|
-
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux,
|
|
556
|
-
self.return_ids)
|
|
570
|
+
grad_ = _Grad(self.get_all, self.get_by_list, self.sens_param, self.get_by_position, self.has_aux,
|
|
571
|
+
self.get_value, self.return_ids, self.merge_forward)
|
|
557
572
|
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
|
|
558
573
|
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
|
|
559
574
|
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
|
|
@@ -577,6 +592,9 @@ class _Grad(GradOperation_):
|
|
|
577
592
|
def after_grad(*args):
|
|
578
593
|
return grad_(fn)(*args)
|
|
579
594
|
elif self.pynative_:
|
|
595
|
+
if not _pynative_executor.enable_grad():
|
|
596
|
+
raise RuntimeError("In no_grad context, you can not calculate gradient")
|
|
597
|
+
|
|
580
598
|
@_wrap_func
|
|
581
599
|
def after_grad(*args, **kwargs):
|
|
582
600
|
res = self._pynative_forward_run(fn, grad_, weights, args, kwargs)
|
|
@@ -591,6 +609,8 @@ class _Grad(GradOperation_):
|
|
|
591
609
|
return out, res[1:]
|
|
592
610
|
return out
|
|
593
611
|
else:
|
|
612
|
+
if not _pynative_executor.enable_grad():
|
|
613
|
+
raise RuntimeError("In no_grad context, you can not calculate gradient")
|
|
594
614
|
grad_.pynative_ = True
|
|
595
615
|
fn_ = fn
|
|
596
616
|
if self.has_aux:
|
|
@@ -682,8 +702,7 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
682
702
|
Args:
|
|
683
703
|
name (str): Operator name.
|
|
684
704
|
read_value (bool, optional): If the registered function do not need to set value on Parameter,
|
|
685
|
-
and all inputs will pass by value, set `read_value` to True. Default: False.
|
|
686
|
-
doc_url (str, optional): The official document link corresponding to the registered function. Default:"".
|
|
705
|
+
and all inputs will pass by value, set `read_value` to ``True`` . Default: ``False`` .
|
|
687
706
|
|
|
688
707
|
Raises:
|
|
689
708
|
ValueError: If failed to find a matching function for the given arguments.
|
|
@@ -715,15 +734,18 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
715
734
|
[0.2 1.2 2.4]
|
|
716
735
|
"""
|
|
717
736
|
|
|
718
|
-
def __init__(self, name, read_value=False
|
|
737
|
+
def __init__(self, name, read_value=False):
|
|
719
738
|
"""Initialize MultitypeFuncGraph."""
|
|
720
|
-
MultitypeFuncGraph_.__init__(self, name
|
|
739
|
+
MultitypeFuncGraph_.__init__(self, name)
|
|
721
740
|
self.entries = list()
|
|
722
741
|
if read_value:
|
|
723
742
|
self.set_signatures((
|
|
724
743
|
sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))
|
|
725
744
|
|
|
726
745
|
def __call__(self, *args):
|
|
746
|
+
for arg in args:
|
|
747
|
+
if isinstance(arg, np.ndarray):
|
|
748
|
+
raise TypeError("For 'MultitypeFuncGraph', the input can not be numpy.ndarray")
|
|
727
749
|
if len(self.entries) == 1:
|
|
728
750
|
output = self.entries[0][1](*args)
|
|
729
751
|
return output
|
|
@@ -766,6 +788,13 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
766
788
|
|
|
767
789
|
return deco
|
|
768
790
|
|
|
791
|
+
# pylint: disable=missing-docstring
|
|
792
|
+
def set_doc_url(self, doc_url):
|
|
793
|
+
self.set_doc_url_(doc_url)
|
|
794
|
+
|
|
795
|
+
def set_need_raise(self):
|
|
796
|
+
self.set_need_raise_()
|
|
797
|
+
|
|
769
798
|
|
|
770
799
|
class HyperMap(HyperMap_):
|
|
771
800
|
"""
|
|
@@ -856,10 +885,10 @@ class Map(Map_):
|
|
|
856
885
|
|
|
857
886
|
Args:
|
|
858
887
|
ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
|
|
859
|
-
the operations should be put in the first input of the instance. Default: None
|
|
888
|
+
the operations should be put in the first input of the instance. Default: ``None`` .
|
|
860
889
|
reverse (bool): The optimizer needs to be inverted in some scenarios to improve parallel performance,
|
|
861
890
|
general users please ignore. `Reverse` is the flag to decide if apply the operation reversely.
|
|
862
|
-
Only supported in graph mode. Default is False.
|
|
891
|
+
Only supported in graph mode. Default is ``False`` .
|
|
863
892
|
|
|
864
893
|
Inputs:
|
|
865
894
|
- **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
|
|
@@ -869,7 +898,7 @@ class Map(Map_):
|
|
|
869
898
|
If `ops` is `None`, the first input is the operation, and the other is inputs.
|
|
870
899
|
|
|
871
900
|
Outputs:
|
|
872
|
-
Sequence, the sequence of output after applying the function. e.g. `
|
|
901
|
+
Sequence, the sequence of output after applying the ops function. e.g. `ops(args[0][i], args[1][i])`.
|
|
873
902
|
|
|
874
903
|
Supported Platforms:
|
|
875
904
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -922,11 +951,7 @@ class _ListAppend(ListAppend_):
|
|
|
922
951
|
Args:
|
|
923
952
|
name (str): The name of the metafuncgraph object.
|
|
924
953
|
"""
|
|
925
|
-
|
|
926
|
-
def __init__(self, name):
|
|
927
|
-
"""Initialize _ListAppend."""
|
|
928
|
-
ListAppend_.__init__(self, name)
|
|
929
|
-
|
|
954
|
+
# `__init__` method removed entirely
|
|
930
955
|
def __call__(self, *args):
|
|
931
956
|
pass
|
|
932
957
|
|
|
@@ -1029,6 +1054,25 @@ class _ListExtend(ListExtend_):
|
|
|
1029
1054
|
_extend = _ListExtend("extend")
|
|
1030
1055
|
|
|
1031
1056
|
|
|
1057
|
+
class _DictSetItem(DictSetItem_):
|
|
1058
|
+
"""
|
|
1059
|
+
A metafuncgraph class that setitem for the dict.
|
|
1060
|
+
|
|
1061
|
+
Args:
|
|
1062
|
+
name (str): The name of the metafuncgraph object.
|
|
1063
|
+
"""
|
|
1064
|
+
|
|
1065
|
+
def __init__(self, name):
|
|
1066
|
+
"""Initialize _DictClear."""
|
|
1067
|
+
DictSetItem_.__init__(self, name)
|
|
1068
|
+
|
|
1069
|
+
def __call__(self, *args):
|
|
1070
|
+
pass
|
|
1071
|
+
|
|
1072
|
+
|
|
1073
|
+
_dict_setitem = _DictSetItem("setitem")
|
|
1074
|
+
|
|
1075
|
+
|
|
1032
1076
|
class _DictClear(DictClear_):
|
|
1033
1077
|
"""
|
|
1034
1078
|
A metafuncgraph class that clear the dict.
|