mindspore 2.1.0__cp39-none-any.whl → 2.2.11__cp39-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/ops/operations/_grad_ops.py

@@ -390,7 +390,7 @@ class Conv2DBackpropFilter(Primitive):
         stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
         dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
         group (int): Splits input into groups. Default: 1.
-        data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW'
+        data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW', \
             default is 'NCHW'.

     Returns:
@@ -636,7 +636,7 @@ class EinsumGrad(PrimitiveWithInfer):

     @prim_attr_register
     def __init__(self, equation):
-
+        pass

     def infer_shape(self, x_shapes, dout_shape):
         out_shape = ()
@@ -1521,9 +1521,11 @@ class LSTMGrad(Primitive):
     """Computes the data and weight gradients of LSTM."""

     @prim_attr_register
-    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
+    def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout, proj_size=0):
         self.input_size = validator.check_positive_int(input_size, 'input_size', self.name)
         self.hidden_size = validator.check_positive_int(hidden_size, 'hidden_size', self.name)
+        self.proj_size = validator.check_int_range(proj_size, 0, hidden_size, validator.INC_LEFT,
+                                                   'proj_size', self.name)
         self.num_layers = validator.check_positive_int(num_layers, 'num_layers', self.name)
         self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name)
         self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name)
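Reading `validator.INC_LEFT` as a left-inclusive interval (an assumption about the validator's convention, mirroring the PyTorch LSTM constraint), the new argument is accepted when 0 <= proj_size < hidden_size. A minimal stand-alone sketch of that check, with a hypothetical helper name:

    # Hypothetical left-inclusive range check, not MindSpore's actual validator.
    def check_int_range_inc_left(value, low, high, name):
        if not isinstance(value, int) or not low <= value < high:
            raise ValueError(f"'{name}' must be an int in [{low}, {high}), but got {value!r}")
        return value

    check_int_range_inc_left(0, 0, 512, 'proj_size')      # ok: the default
    check_int_range_inc_left(128, 0, 512, 'proj_size')    # ok
    # check_int_range_inc_left(512, 0, 512, 'proj_size')  # raises: right bound excluded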
@@ -2573,7 +2575,12 @@ class MultilabelMarginLossGrad(Primitive):
     Compute the gradients of MultilabelMarginLoss operation.

     Args:
-        reduction (str): Apply specific reduction method to the output: 'none', 'mean',
+        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+            ``'sum'`` . Default: ``'mean'`` .
+
+            - ``'none'``: no reduction will be applied.
+            - ``'mean'``: compute and return the mean of elements in the output.
+            - ``'sum'``: the output elements will be summed.

     Inputs:
         - **y_grad** (Tensor) - The gradients of loss to output of MultilabelMarginLoss function, with
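This docstring change (repeated for MultiMarginLossGrad below) documents behavior that already exists. For readers unfamiliar with the three modes, an illustrative numpy sketch, not MindSpore code:

    import numpy as np

    per_element_loss = np.array([0.2, 0.8, 0.5], dtype=np.float32)
    none_out = per_element_loss            # 'none': the unreduced vector
    mean_out = per_element_loss.mean()     # 'mean': scalar average -> 0.5
    sum_out = per_element_loss.sum()       # 'sum':  scalar total   -> 1.5
    print(none_out, mean_out, sum_out)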
@@ -2595,7 +2602,7 @@ class MultilabelMarginLossGrad(Primitive):
         TypeError: If dtype of `y_grad` is not the same as `x`.
         ValueError: If length of shape of `x` is neither 1 nor 2.
         ValueError: If shape of `x` is not the same as `target`.
-        ValueError: If `reduction` is not one of 'none'
+        ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
         ValueError: If shape of `y_grad` is not the same as forward output `y`.

     Supported Platforms:
@@ -2862,7 +2869,9 @@ class Dilation2DBackpropFilter(Primitive):
         self.pad_mode = validator.check_string(self.pad_mode, ["SAME", "VALID", 'same', "valid"], "pad_mode", self.name)
         self.add_prim_attr("pad_mode", self.pad_mode.upper())
         self.stride = _check_format_stride_or_dilation("stride", stride, self.name, self.data_format)
-
+        def is_in_range(x):
+            return 1 <= x <= 255
+        if not is_in_range(self.stride[2]) or not is_in_range(self.stride[3]):
             raise ValueError(f"For '{self.name}', size of stride is not supported, "
                              f'stride should be in the range of [1, 255], '
                              f'but got stride_h: `{self.stride[2]}`, stride_w: `{self.stride[3]}`.')
@@ -2917,7 +2926,12 @@ class MultiMarginLossGrad(Primitive):
     Args:
         p (int): Optional. The norm degree for pairwise distance.Should be 1 or 2. Default: 1.
         margin (float): Optional. A parameter to change pairwise distance. Default: 1.0.
-        reduction (str): Apply specific reduction method to the output: 'none', 'mean',
+        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+            ``'sum'`` . Default: ``'mean'`` .
+
+            - ``'none'``: no reduction will be applied.
+            - ``'mean'``: compute and return the weighted mean of elements in the output.
+            - ``'sum'``: the output elements will be summed.

     Inputs:
         - **y_grad** (Tensor) - If it's not a scalar, the shape of 'y_grad' :math:`(N, C)`.
@@ -3818,3 +3832,53 @@ class WKVGrad(Primitive):
         """Initialize WKVGrad."""
         self.init_prim_io_names(inputs=["time_first", "time_decay", "key", "value", "gy"],
                                 outputs=["gw", "gu", "gk", "gv"])
+
+
+class FlashAttentionScoreGrad(Primitive):
+    r"""
+    Calculates the gradient of FlashAttentionScore operation.
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+    @prim_attr_register
+    def __init__(self, head_num, keep_prob=1.0, scale_value=1.0, pre_tokens=65536, next_tokens=65536, inner_precise=1,
+                 input_layout='BSH', sparse_mode=0):
+        """Initialize FlashAttentionScoreGrad."""
+        validator.check_value_type('head_num', head_num, [int], self.name)
+        validator.check_value_type('keep_prob', keep_prob, [int, float], self.name)
+        validator.check_float(keep_prob, 0.0, validator.GE, "keep_prob", self.name)
+        validator.check_float(keep_prob, 1.0, validator.LE, "keep_prob", self.name)
+        validator.check_value_type('scale_value', scale_value, [float], self.name)
+        validator.check_value_type('pre_tokens', pre_tokens, [int], self.name)
+        validator.check_value_type('next_tokens', next_tokens, [int], self.name)
+        validator.check_value_type('inner_precise', inner_precise, [int], self.name)
+        validator.check_value_type('sparse_mode', sparse_mode, [int], self.name)
+        if inner_precise not in [0, 1]:
+            raise ValueError(f"Attribute 'inner_precise' must be either 0 or 1, but got {inner_precise}")
+        validator.check_value_type('input_layout', input_layout, [str], self.name)
+        if input_layout not in ["BSH", "BNSD"]:
+            raise ValueError(f"Attribute 'input_layout' must be either 'BSH' or 'BNSD', but got {input_layout}")
+        self.init_prim_io_names(inputs=['query', 'key', 'value', 'dy', 'pse_shift', 'drop_mask', "padding_mask",
+                                        'attn_mask', 'softmax_max', 'softmax_sum', 'softmax_out', 'attention_in',
+                                        'prefix'],
+                                outputs=['dq', 'dk', 'dv', 'dpse'])
+
+
+class RmsNormGrad(Primitive):
+    r"""
+    Calculates the gradient of RmsNorm operation.
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize RmsNormGrad."""
+        self.init_prim_io_names(inputs=["dy", "x", "rstd", "gamma"],
+                                outputs=["dx", "dgamma"])
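Both new primitives validate their attributes at construction time. A hedged sketch of constructing FlashAttentionScoreGrad and tripping its checks (assuming these classes land in mindspore/ops/operations/_grad_ops.py as the file list suggests; the 13 declared input tensors are omitted):

    from mindspore.ops.operations._grad_ops import FlashAttentionScoreGrad

    # Valid: BSH layout, default full attention window.
    grad_op = FlashAttentionScoreGrad(head_num=8, keep_prob=0.9, scale_value=0.125,
                                      input_layout='BSH')

    # Each of these would raise ValueError per the checks above:
    # FlashAttentionScoreGrad(head_num=8, inner_precise=2)     # inner_precise not in [0, 1]
    # FlashAttentionScoreGrad(head_num=8, input_layout='SBH')  # layout not in ['BSH', 'BNSD']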
mindspore/ops/operations/_inner_ops.py

@@ -23,16 +23,17 @@ from mindspore.common._stub_tensor import StubTensor
 from mindspore.ops import composite as C
 from mindspore.ops.operations.array_ops import Cast
 from mindspore.ops.operations._scalar_ops import bit_or, bit_and
+from mindspore.ops.operations.comm_ops import ReduceOp
 from mindspore.ops import signature as sig
 from mindspore.ops.operations.math_ops import _infer_shape_reduce
-from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive,
-
+from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
+    _run_op, _check_contains_variable
 from mindspore._c_expression import Tensor as Tensor_
 from mindspore._c_expression import typing
 from mindspore import _checkparam as validator
 from mindspore.common import dtype as mstype
 from mindspore.common.parameter import Parameter
-from mindspore.communication.management import GlobalComm
+from mindspore.communication.management import GlobalComm, get_rank
 from mindspore.common.api import _pynative_executor
 from mindspore.common._register_for_adapter import ms_adapter_registry
 from mindspore import ops
@@ -74,11 +75,11 @@ class ExtractImagePatches(Primitive):
         - valid: Means that the taken patch area must be completely covered in the original image.

     Inputs:
-        - **input_x** (Tensor) - A 4-D tensor whose shape is :math:`(
+        - **input_x** (Tensor) - A 4-D tensor whose shape is :math:`(in\_batch, in\_depth, in\_row, in\_col)`.

     Outputs:
         Tensor, a 4-D tensor whose data type is same as 'input_x', and the shape
-        is :math:`(
+        is :math:`(out\_batch, out\_depth, out\_row, out\_col)`,where the out_batch is the same as the in_batch
         and

     .. math::
@@ -121,7 +122,6 @@ class ExtractImagePatches(Primitive):
         validator.check_value_type('padding', padding, [str], self.name)
         self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
         self.add_prim_attr("padding", self.padding)
-        self.is_ge = context.get_context("enable_ge")


 class Quant(PrimitiveWithInfer):
@@ -167,6 +167,7 @@ class Quant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                  "round_mode", self.name)
+        self.add_prim_attr("dst_type", mstype.int8)

     def infer_shape(self, x_shape):
         return x_shape
@@ -174,7 +175,7 @@ class Quant(PrimitiveWithInfer):
     def infer_dtype(self, x_type):
         validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
         validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
-        return
+        return self.get_attr_dict()['dst_type']


 class Lamb(PrimitiveWithInfer):
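The net effect of these two Quant hunks: the destination dtype is registered once as the `dst_type` attribute and reported from `infer_dtype`, instead of being left implicit. For intuition, a numpy sketch of the quantization this op is documented to perform (assuming the formula y = round(scale * x + offset) saturated to int8, with sqrt_mode off and round mode "Round"):

    import numpy as np

    def quant_sketch(x, scale=1.0, offset=0.0):
        # Round, then saturate to the int8 range that dst_type advertises.
        y = np.round(scale * x + offset)
        return np.clip(y, -128, 127).astype(np.int8)

    x = np.array([0.4, 1.6, -3.2], dtype=np.float32)
    print(quant_sketch(x, scale=100.0))  # -> [  40  127 -128]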
@@ -491,7 +492,7 @@ class Receive(PrimitiveWithInfer):
         self.dtype = dtype
         self.group = group
         self.add_prim_attr("no_eliminate", True)
-        valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
+        valid_type = [mstype.float16, mstype.bfloat16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
         args = {"dtype": dtype}
         validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
@@ -502,6 +503,109 @@ class Receive(PrimitiveWithInfer):
         return self.get_attr_dict()['dtype']


+class Reduce(PrimitiveWithInfer):
+    """
+    Reduces tensor across the processes in the specified communication group.
+
+    Note:
+        Only process with destination rank receives the reduced output.
+        Other processes only get a tensor with shape [1], which has no mathematical meaning.
+
+    Args:
+        dest_rank (int): Specifies the rank of the process that receives the reduced output.
+        op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
+            On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
+        group (str, optional): The communication group to work on.
+            Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> # Launch 4 processes.
+        >>> init()
+        >>> class ReduceNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.reduce = ops.Reduce(dest_rank=1)
+        >>>
+        >>>     def construct(self, x):
+        >>>         out = self.reduce(x)
+        >>>         return out
+        >>> input = Tensor(np.ones([2, 8]).astype(np.float32))
+        >>> net = ReduceNet()
+        >>> output = net(input)
+        >>> print(output)
+        Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
+                             [4. 4. 4. 4. 4. 4. 4. 4.]],
+        Other proesses: [0.].
+    """
+
+    @prim_attr_register
+    def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
+        self.dest_rank = dest_rank
+        self.op = op
+        self.group = group
+
+    def infer_shape(self, x_shape):
+        # The process with dest_rank returns the reduced output.
+        # Other processes only gets a tensor with shape [1], which has no mathematical meaning.
+        if self.dest_rank == get_rank():
+            return x_shape
+        return [1]
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
+
+
+class Barrier(PrimitiveWithInfer):
+    """
+    Synchronizes all processes in the specified group.
+
+    Note:
+        After calling this collective operator,
+        this process will be blocked until all other processes in the group call this operator.
+
+    Args:
+        group (str, optional): The communication group to work on.
+            Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
+
+    Examples:
+        >>> import mindspore.ops as ops
+        >>> import mindspore.nn as nn
+        >>> from mindspore.communication import init
+        >>> from mindspore import Tensor
+        >>> import numpy as np
+        >>> # Launch 4 processes.
+        >>> init()
+        >>> class BarrierNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.barrier = ops.Barrier()
+        >>>
+        >>>     def construct(self):
+        >>>         self.barrier()
+        >>> net = BarrierNet()
+        >>> net()
+    """
+
+    @prim_attr_register
+    def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
+        self.group = group
+        self.add_prim_attr("side_effect_mem", True)
+
+    def infer_shape(self):
+        return [1]
+
+    def infer_dtype(self):
+        return mstype.float32
+
+
 class MatrixSetDiag(PrimitiveWithInfer):
     r"""
     Modifies the batched diagonal part of a batched tensor.
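The docstring examples above call `super(Net, self).__init__()` inside classes named `ReduceNet`/`BarrierNet`. A self-consistent sketch of the intended usage (assuming a 4-process launch, e.g. via mpirun or msrun, and that the primitive is reachable as its own example shows):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops as ops
    from mindspore import Tensor
    from mindspore.communication import init

    init()  # requires a multi-process communication launch

    class ReduceNet(nn.Cell):
        def __init__(self):
            super(ReduceNet, self).__init__()
            self.reduce = ops.Reduce(dest_rank=1)  # rank 1 receives the reduced sum

        def construct(self, x):
            return self.reduce(x)

    output = ReduceNet()(Tensor(np.ones([2, 8]).astype(np.float32)))
    print(output)  # rank 1: a (2, 8) tensor of 4s; other ranks: shape [1] placeholder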
@@ -1843,16 +1947,32 @@ class Format(PrimitiveWithInfer):
     def __init__(self):
         self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])

+
     def __infer__(self, str_, *var):
-
+        def check_variable(str_, var):
+            if _check_contains_variable(str_['dtype'], str_['value']):
+                return True
+
+            for item in var:
+                if _check_contains_variable(item['dtype'], item['value']):
+                    return True
+            return False
+
+
+        if check_variable(str_, var):
+            return {'dtype': mstype.string, 'shape': [], 'value': None}
+
+
+        str_value = str_['value']
+        kwargs = dict()
         var_value = list()
-
-        raise ValueError("str.format not support to input a variable.")
+
         for item in var:
-            if item["
-
+            if isinstance(item["dtype"], typing.Keyword):
+                kwargs.update(item["value"])
             var_value.append(item["value"])
-
+
+        value = str_value.format(*var_value, **kwargs)
         return {'dtype': mstype.string, 'shape': [], 'value': value}

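In plain Python terms, the new `__infer__` constant-folds `str.format` when every piece is a compile-time constant and defers otherwise (previously it always raised). A sketch of that behavior, with graph variables modeled as None, not the actual infer machinery:

    def infer_format(template, *args, **kwargs):
        # A still-unknown graph variable is modeled here as None.
        if template is None or any(v is None for v in (*args, *kwargs.values())):
            return None  # contains a variable: leave the folding to runtime
        return template.format(*args, **kwargs)

    print(infer_format("loss={:.2f} at step {step}", 0.25, step=3))  # loss=0.25 at step 3
    print(infer_format("loss={}", None))                             # None -> deferred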
@@ -2027,13 +2147,14 @@ class ClipByNorm(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, axis=None):
         """Initialize ClipByNorm"""
+        self.axis_str = 'axis'
         self.axis = () if axis is None else axis
-        validator.check_value_type(
+        validator.check_value_type(self.axis_str, self.axis, [int, tuple, list], self.name)
         axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
         for i, value in enumerate(axis_check):
             validator.check_value_type('axis[%d]' % i, value, [int], self.name)
-        self.init_attrs[
-        self.add_prim_attr(
+        self.init_attrs[self.axis_str] = self.axis
+        self.add_prim_attr(self.axis_str, self.axis)
         self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])

     def infer_shape(self, x_shape, clip_norm_shape):
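This hunk is a naming refactor; ClipByNorm's semantics are unchanged. As a reference for the op being configured, a numpy sketch of the documented behavior (scale values so the L2 norm over `axis` does not exceed `clip_norm`, leave them untouched otherwise; an illustration, not the kernel):

    import numpy as np

    def clip_by_norm_sketch(x, clip_norm, axis=None):
        l2 = np.sqrt(np.sum(x * x, axis=axis, keepdims=True))
        return x * clip_norm / np.maximum(l2, clip_norm)

    x = np.array([[3.0, 4.0]], dtype=np.float32)   # L2 norm = 5
    print(clip_by_norm_sketch(x, clip_norm=1.0))   # [[0.6 0.8]]
    print(clip_by_norm_sketch(x, clip_norm=10.0))  # unchanged: [[3. 4.]]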
@@ -2588,3 +2709,215 @@ class IsConstant(Primitive):
|
|
|
2588
2709
|
|
|
2589
2710
|
def __call__(self, x):
|
|
2590
2711
|
return True
|
|
2712
|
+
|
|
2713
|
+
|
|
2714
|
+
class SelectView(Primitive):
|
|
2715
|
+
r"""
|
|
2716
|
+
Select tensor of view
|
|
2717
|
+
"""
|
|
2718
|
+
|
|
2719
|
+
@prim_attr_register
|
|
2720
|
+
def __init__(self):
|
|
2721
|
+
self.init_prim_io_names(inputs=['input_tensor', 'input_indices', 'axis'], outputs=['output'])
|
|
2722
|
+
|
|
2723
|
+
|
|
2724
|
+
class CopyWithSlice(Primitive):
|
|
2725
|
+
r"""
|
|
2726
|
+
Copy data to discontinuous tensor
|
|
2727
|
+
"""
|
|
2728
|
+
@prim_attr_register
|
|
2729
|
+
def __init__(self):
|
|
2730
|
+
self.add_prim_attr('side_effect_mem', True)
|
|
2731
|
+
self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])
|
|
2732
|
+
|
|
2733
|
+
|
|
2734
|
+
class FFN(Primitive):
|
|
2735
|
+
r"""
|
|
2736
|
+
The FFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
|
|
2737
|
+
|
|
2738
|
+
Args:
|
|
2739
|
+
activation (string): The activation type, set to 'fastgelu' or 'gelu'.
|
|
2740
|
+
Only support 'fastgelu' for now. Default: "fastgelu".
|
|
2741
|
+
inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
|
|
2742
|
+
Only support 1 for now. Default: 0.
|
|
2743
|
+
|
|
2744
|
+
Inputs:
|
|
2745
|
+
- **x** (Tensor) - The input tensor with data type of int8, float16.
|
|
2746
|
+
Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
|
|
2747
|
+
- **weight1** (Tensor) - The weight1 tensor with data type of float16.
|
|
2748
|
+
Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
|
|
2749
|
+
- **weight2** (Tensor) - The weight2 tensor with data type of float16.
|
|
2750
|
+
Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
|
|
2751
|
+
- **expert_tokens** (Tensor]) - The expert tokens tensor with data type of int64.
|
|
2752
|
+
Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
|
|
2753
|
+
indicate that the 0th expert deals with 2 tokens, the 1th expert deals with 1 tokens,
|
|
2754
|
+
the 2th expert do noting and so on.
|
|
2755
|
+
- **bias1** (Tensor) - The bias1 tensor with data type of float16.
|
|
2756
|
+
Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
|
|
2757
|
+
- **bias2** (Tensor) - The bias2 tensor with data type of float16.
|
|
2758
|
+
Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
|
|
2759
|
+
- **scale** (Tensor) - The scale tensor with data type of float16. Not enable now.
|
|
2760
|
+
- **offset** (Tensor) - The offset tensor with data type of float16. Not enable now.
|
|
2761
|
+
- **deq_scale1** (Tensor) - The deq_scale1 tensor with data type of float16. Not enable now.
|
|
2762
|
+
- **deq_scale2** (Tensor) - The deq_scale2 tensor with data type of float16. Not enable now.
|
|
2763
|
+
|
|
2764
|
+
Outputs:
|
|
2765
|
+
Tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`. With data type of float16.
|
|
2766
|
+
|
|
2767
|
+
Supported Platforms:
|
|
2768
|
+
``Ascend``
|
|
2769
|
+
|
|
2770
|
+
Examples:
|
|
2771
|
+
>>> from mindspore.ops.operations import _inner_ops
|
|
2772
|
+
>>> b = 4
|
|
2773
|
+
>>> s = 128
|
|
2774
|
+
>>> h = 1024
|
|
2775
|
+
>>> h_f = 4 * h
|
|
2776
|
+
>>> e = 16
|
|
2777
|
+
>>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
|
|
2778
|
+
>>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
|
|
2779
|
+
>>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
|
|
2780
|
+
>>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
|
|
2781
|
+
>>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
|
|
2782
|
+
>>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
|
|
2783
|
+
>>> ffn = _inner_ops.FFN("fastgelu", 1)
|
|
2784
|
+
>>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
|
|
2785
|
+
>>> print(output)
|
|
2786
|
+
"""
|
|
2787
|
+
|
|
2788
|
+
@prim_attr_register
|
|
2789
|
+
def __init__(self, activation, inner_precise):
|
|
2790
|
+
"""Initialize FFN."""
|
|
2791
|
+
self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
|
|
2792
|
+
"bias2", "scale", "offset", "deq_scale1", "deq_scale2"],
|
|
2793
|
+
outputs=["y"])
|
|
2794
|
+
cls_name = self.name
|
|
2795
|
+
validator.check_value_type("activation", activation, [str], cls_name)
|
|
2796
|
+
validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
|
|
2797
|
+
|
|
2798
|
+
|
|
2799
|
+
+class DecoderKVCache(Primitive):
+    r"""
+    The DecoderKVCache is used for decoding the KVCache of a transformer network.
+
+    Args:
+        cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 or int32.
+            When seq_len_axis is 2, the cache tensor has shape
+            :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, the cache tensor has shape
+            :math:`(batch\_size, max\_seq\_length, num\_head, hidden\_size)`.
+        update (Tensor): The tensor used to update the cache tensor. Same data type as the cache tensor.
+            When seq_len_axis is 2, the update tensor has shape
+            :math:`(batch\_size, num\_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, the update tensor has shape
+            :math:`(batch\_size, update\_seq\_length, num\_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64, of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64, of shape :math:`(1)`.
+            Indicates which batch of the cache tensor is going to be updated.
+        seq_len_axis (int64): Indicates which axis is the seq_len axis; set to 1 or 2. Default: 2.
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
+            Indicates that the user wants to reinterpret the shape of the cache tensor from
+            :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)` to
+            :math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num\_head, new\_max\_seq\_length, hidden\_size)`
+            when updating the cache tensor. This does not actually change the shape of the `cache` tensor.
+            Not available yet.
+        cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
+            Keeps the current seq_len of the cache tensor. Not available yet.
+
+    Outputs:
+        A tensor with the same data type and shape as the `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 1
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> decoder_kv_cache = _inner_ops.DecoderKVCache()
+        >>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self):
+        """Initialize DecoderKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+
+
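For orientation, the decode-step semantics suggested by the docstring (a short update written in place at each batch's current valid position) can be sketched in NumPy as below. This is a minimal sketch under the `(batch, num_head, seq, hidden)` layout with `seq_len_axis == 2`, not the Ascend kernel; the handling of `batch_index` and of the "not available yet" arguments is deliberately omitted, and `decoder_kv_cache_reference` is a name of ours.

    import numpy as np

    def decoder_kv_cache_reference(cache, update, valid_seq_len):
        # Hedged sketch: write the update slice into the cache in place at
        # each batch's valid position. Mutating `cache` mirrors the
        # side_effect_mem attribute set above. Assumes pos + s <= max_s.
        s = update.shape[2]
        for b in range(cache.shape[0]):
            pos = int(valid_seq_len[b])
            cache[b, :, pos:pos + s, :] = update[b]
        return cache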
+class PromptKVCache(Primitive):
+    r"""
+    The PromptKVCache is used to prefill the KVCache of a transformer network.
+
+    Args:
+        cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 or int32.
+            When seq_len_axis is 2, the cache tensor has shape
+            :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, the cache tensor has shape
+            :math:`(batch\_size, max\_seq\_length, num\_head, hidden\_size)`.
+        update (Tensor): The tensor used to update the cache tensor. Same data type as the cache tensor.
+            When seq_len_axis is 2, the update tensor has shape
+            :math:`(batch\_size, num\_head, update\_seq\_length, hidden\_size)`.
+            When seq_len_axis is 1, the update tensor has shape
+            :math:`(batch\_size, update\_seq\_length, num\_head, hidden\_size)`.
+        valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64, of shape :math:`(batch\_size)`.
+        batch_index (Tensor): The batch_index tensor with data type of int64, of shape :math:`(1)`.
+            Indicates which batch of the cache tensor is going to be updated.
+        seq_len_axis (int64): Indicates which axis is the seq_len axis; set to 1 or 2. Default: 2.
+        new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
+            Indicates that the user wants to reinterpret the shape of the cache tensor from
+            :math:`(batch\_size, num\_head, max\_seq\_length, hidden\_size)` to
+            :math:`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num\_head, new\_max\_seq\_length, hidden\_size)`
+            when updating the cache tensor. This does not actually change the shape of the `cache` tensor.
+            Not available yet.
+        cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64, of shape :math:`(1)`.
+            Keeps the current seq_len of the cache tensor. Not available yet.
+        align_mode (int64): Indicates the alignment mode of the update within the cache; 0 is 'right',
+            1 is 'left'. Default: 0.
+
+    Outputs:
+        A tensor with the same data type and shape as the `cache` tensor.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops.operations import _inner_ops
+        >>> b = 4
+        >>> h = 40
+        >>> max_s = 1024
+        >>> s = 256
+        >>> d = 128
+        >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+        >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+        >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+        >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+        >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+        >>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
+        >>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+        >>> print(cache)
+    """
+    @prim_attr_register
+    def __init__(self, padding_mode="right"):
+        """Initialize PromptKVCache."""
+        self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+                                        "new_max_seq_len", "cur_max_seq_len"],
+                                outputs=["out"])
+        self.add_prim_attr('side_effect_mem', True)
+        self.padding_mode = padding_mode
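PromptKVCache differs from DecoderKVCache mainly in writing a whole prompt segment at once and in its alignment attribute. A minimal NumPy sketch of the two alignment choices for the `(batch, num_head, seq, hidden)` layout follows; the right/left placement semantics are an assumption inferred from the `align_mode`/`padding_mode` naming, and `prompt_kv_cache_reference` is an illustrative name.

    import numpy as np

    def prompt_kv_cache_reference(cache, update, align_mode=0):
        # Hedged sketch for seq_len_axis == 2. Assumed semantics:
        # align_mode 0 places the prompt right-aligned in the cache
        # window, align_mode 1 places it left-aligned.
        s, max_s = update.shape[2], cache.shape[2]
        if align_mode == 0:
            cache[:, :, max_s - s:, :] = update   # right-aligned
        else:
            cache[:, :, :s, :] = update           # left-aligned
        return cache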
@@ -269,7 +269,7 @@ class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
         - **quant_max** (Tensor) : Value of the quantization range.
 
     Outputs:
-        - Tensor: Simulates quantize tensor of `input_x
+        - Tensor: Simulates quantize tensor of `input_x`, with the same type and shape as the `input_x`.
 
     Examples:
         >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
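The corrected output description (same type and shape as `input_x`) matches the usual fake-quantization contract: the tensor is quantized and immediately dequantized, so only its values change. As a hedged illustration of a learned-scale (LSQ-style) simulation, with no claim that it matches the on-device kernel exactly, and with `fake_learned_scale_quant_reference` and the `neg_trunc` flag treated as illustrative:

    import numpy as np

    def fake_learned_scale_quant_reference(x, alpha, quant_max, neg_trunc=False):
        # Hedged LSQ-style fake quantization: normalise by the learned
        # scale alpha, clip, quantise/dequantise against quant_max, then
        # rescale. Output keeps the type and shape of x.
        lo = 0.0 if neg_trunc else -1.0
        x_norm = np.clip(x / alpha, lo, 1.0)
        x_dq = np.round(x_norm * quant_max) / quant_max
        return (x_dq * alpha).astype(x.dtype)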
@@ -419,7 +419,7 @@ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
         - **quant_max** (Tensor) : Value of the quantization range.
 
     Outputs:
-        - Tensor: Simulates quantize tensor of `input_x
+        - Tensor: Simulates quantize tensor of `input_x`, with the same type and shape as the `input_x`.
 
     Examples:
         >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
@@ -975,7 +975,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         >>> result = fake_quant(input_x, _min, _max)
     """
     support_quant_bit = [4, 7, 8]
-    ascend_support_x_rank = [2, 4]
+    ascend_support_x_rank = [2, 3, 4]
 
     @prim_attr_register
     def __init__(self,
@@ -1008,11 +1008,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
         self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
         self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
-
-            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
-                                                          'channel_axis', self.name)
-        else:
-            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
+        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
 
     def infer_shape(self, x_shape, min_shape, max_shape):
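With the device-specific branch removed, `channel_axis` is validated with a single non-negative check on every path, consistent with the wider `ascend_support_x_rank = [2, 3, 4]` above. A minimal NumPy sketch of what per-channel fake quantization does with an arbitrary `channel_axis`, assuming the standard asymmetric min/max scheme (symmetric and narrow_range variants omitted; `fake_quant_per_channel_reference` is a name of ours):

    import numpy as np

    def fake_quant_per_channel_reference(x, x_min, x_max, channel_axis=1, num_bits=8):
        # Hedged simulation: per-channel min/max are broadcast along
        # channel_axis; values are clipped, quantised to 2**num_bits - 1
        # steps, and dequantised back to the input's dtype and shape.
        steps = 2 ** num_bits - 1
        shape = [1] * x.ndim
        shape[channel_axis] = -1
        mn, mx = x_min.reshape(shape), x_max.reshape(shape)
        scale = np.maximum((mx - mn) / steps, 1e-12)  # guard degenerate ranges
        q = np.round((np.clip(x, mn, mx) - mn) / scale)
        return (q * scale + mn).astype(x.dtype)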