mindspore 2.2.14__cp39-cp39-win_amd64.whl → 2.4.0__cp39-cp39-win_amd64.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +8 -5
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +124 -25
- mindspore/_extends/builtin_operations.py +2 -1
- mindspore/_extends/graph_kernel/model/graph_parallel.py +16 -6
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +3 -16
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +16 -4
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +1 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +2 -1
- mindspore/_extends/parallel_compile/akg_compiler/util.py +5 -2
- mindspore/_extends/parse/__init__.py +18 -14
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +2 -2
- mindspore/_extends/parse/parser.py +182 -68
- mindspore/_extends/parse/resources.py +45 -14
- mindspore/_extends/parse/standard_method.py +192 -252
- mindspore/{ops/_op_impl/tbe/atomic_addr_clean.py → _extends/pijit/__init__.py} +6 -16
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/kernel_build_server.py +2 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +67 -26
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/base.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +2 -2
- mindspore/boost/grad_freeze.py +2 -2
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/boost/less_batch_normalization.py +9 -6
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +20 -7
- mindspore/common/_jit_fallback_utils.py +2 -3
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +7 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +10 -10
- mindspore/common/_stub_tensor.py +7 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +5 -17
- mindspore/common/api.py +449 -129
- mindspore/common/auto_dynamic_shape.py +27 -14
- mindspore/common/dtype.py +17 -10
- mindspore/common/dump.py +8 -11
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +65 -30
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +34 -14
- mindspore/common/lazy_inline.py +72 -19
- mindspore/common/mindir_util.py +12 -2
- mindspore/common/mutable.py +79 -14
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +73 -21
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +9 -9
- mindspore/common/sparse_tensor.py +276 -24
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +668 -514
- mindspore/communication/__init__.py +6 -11
- mindspore/communication/_comm_helper.py +43 -3
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +117 -104
- mindspore/config/op_info.config +22 -54
- mindspore/context.py +455 -71
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +6 -6
- mindspore/dataset/audio/transforms.py +711 -158
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +2 -2
- mindspore/dataset/engine/datasets.py +201 -116
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +83 -3
- mindspore/dataset/engine/datasets_text.py +39 -39
- mindspore/dataset/engine/datasets_user_defined.py +230 -141
- mindspore/dataset/engine/datasets_vision.py +78 -74
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/offload.py +5 -7
- mindspore/dataset/engine/queue.py +138 -66
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +41 -15
- mindspore/dataset/text/__init__.py +2 -5
- mindspore/dataset/text/transforms.py +408 -121
- mindspore/dataset/text/utils.py +9 -9
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/transforms/transforms.py +261 -76
- mindspore/dataset/utils/browse_dataset.py +9 -9
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +7 -10
- mindspore/dataset/vision/c_transforms.py +10 -10
- mindspore/dataset/vision/py_transforms_util.py +1 -1
- mindspore/dataset/vision/transforms.py +2844 -549
- mindspore/dataset/vision/utils.py +161 -10
- mindspore/dataset/vision/validators.py +16 -3
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{rewrite/ast_creator_register.py → experimental/es/__init__.py} +5 -20
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/__init__.py +12 -2
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +35 -34
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +124 -15
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +66 -121
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +18 -8
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +28 -19
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/api/data_type.h +2 -2
- mindspore/include/api/dual_abi_helper.h +16 -3
- mindspore/include/api/model.h +4 -3
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/status.h +14 -0
- mindspore/include/api/types.h +10 -10
- mindspore/include/c_api/model_c.h +173 -0
- mindspore/include/c_api/types_c.h +19 -0
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +3 -5
- mindspore/include/dataset/vision.h +58 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +3 -3
- mindspore/mindrecord/__init__.py +5 -1
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +25 -0
- mindspore/mindrecord/filewriter.py +138 -103
- mindspore/mindrecord/mindpage.py +40 -6
- mindspore/mindrecord/shardutils.py +3 -2
- mindspore/mindrecord/shardwriter.py +7 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +8 -13
- mindspore/mindrecord/tools/cifar10_to_mr.py +9 -15
- mindspore/mindrecord/tools/csv_to_mr.py +4 -9
- mindspore/mindrecord/tools/imagenet_to_mr.py +3 -8
- mindspore/mindrecord/tools/mnist_to_mr.py +7 -12
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -6
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/{rewrite/ast_transformers → mint/linalg}/__init__.py +9 -4
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/cell.py +461 -323
- mindspore/nn/dynamic_lr.py +2 -2
- mindspore/nn/layer/activation.py +292 -135
- mindspore/nn/layer/basic.py +288 -83
- mindspore/nn/layer/channel_shuffle.py +3 -16
- mindspore/nn/layer/container.py +3 -3
- mindspore/nn/layer/conv.py +75 -66
- mindspore/nn/layer/embedding.py +221 -45
- mindspore/nn/layer/image.py +4 -7
- mindspore/nn/layer/math.py +1 -1
- mindspore/nn/layer/normalization.py +150 -68
- mindspore/nn/layer/padding.py +64 -87
- mindspore/nn/layer/pooling.py +175 -12
- mindspore/nn/layer/rnn_cells.py +6 -16
- mindspore/nn/layer/rnns.py +6 -5
- mindspore/nn/layer/thor_layer.py +1 -2
- mindspore/nn/layer/timedistributed.py +1 -1
- mindspore/nn/layer/transformer.py +55 -53
- mindspore/nn/learning_rate_schedule.py +6 -5
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +145 -88
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/ada_grad.py +4 -2
- mindspore/nn/optim/adadelta.py +4 -2
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +102 -181
- mindspore/nn/optim/adamax.py +4 -2
- mindspore/nn/optim/adasum.py +3 -3
- mindspore/nn/optim/asgd.py +4 -2
- mindspore/nn/optim/ftrl.py +31 -61
- mindspore/nn/optim/lamb.py +5 -3
- mindspore/nn/optim/lars.py +2 -2
- mindspore/nn/optim/lazyadam.py +6 -4
- mindspore/nn/optim/momentum.py +13 -25
- mindspore/nn/optim/optimizer.py +6 -3
- mindspore/nn/optim/proximal_ada_grad.py +4 -2
- mindspore/nn/optim/rmsprop.py +9 -3
- mindspore/nn/optim/rprop.py +4 -2
- mindspore/nn/optim/sgd.py +5 -3
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/custom_ops.py +2 -2
- mindspore/nn/probability/distribution/beta.py +2 -2
- mindspore/nn/probability/distribution/categorical.py +4 -6
- mindspore/nn/probability/distribution/cauchy.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -2
- mindspore/nn/probability/distribution/geometric.py +1 -1
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +1 -1
- mindspore/nn/probability/distribution/poisson.py +2 -2
- mindspore/nn/probability/distribution/uniform.py +2 -2
- mindspore/nn/reinforcement/_tensors_queue.py +13 -1
- mindspore/nn/wrap/__init__.py +2 -1
- mindspore/nn/wrap/cell_wrapper.py +46 -12
- mindspore/nn/wrap/grad_reducer.py +148 -8
- mindspore/nn/wrap/loss_scale.py +44 -7
- mindspore/numpy/__init__.py +2 -0
- mindspore/numpy/array_creations.py +67 -68
- mindspore/numpy/array_ops.py +70 -66
- mindspore/numpy/dtypes.py +3 -3
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +147 -152
- mindspore/numpy/utils.py +3 -0
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +9 -6
- mindspore/ops/_grad_experimental/grad_array_ops.py +4 -129
- mindspore/ops/_grad_experimental/grad_comm_ops.py +135 -36
- mindspore/ops/_grad_experimental/grad_math_ops.py +61 -298
- mindspore/ops/_grad_experimental/grad_nn_ops.py +0 -53
- mindspore/ops/_grad_experimental/grad_quant_ops.py +3 -3
- mindspore/ops/_grad_experimental/grad_sparse.py +1 -1
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/__init__.py +0 -1
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +1 -1
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +1 -3
- mindspore/ops/_op_impl/aicpu/poisson.py +2 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -3
- mindspore/ops/_op_impl/cpu/adam.py +2 -2
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +3 -2
- mindspore/ops/_op_impl/cpu/maximum_grad.py +16 -14
- mindspore/ops/_op_impl/cpu/minimum_grad.py +8 -0
- mindspore/ops/_vmap/vmap_array_ops.py +162 -101
- mindspore/ops/_vmap/vmap_base.py +8 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +95 -9
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +143 -58
- mindspore/ops/_vmap/vmap_image_ops.py +70 -13
- mindspore/ops/_vmap/vmap_math_ops.py +147 -59
- mindspore/ops/_vmap/vmap_nn_ops.py +292 -117
- mindspore/ops/_vmap/vmap_other_ops.py +1 -1
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +201 -66
- mindspore/ops/composite/math_ops.py +10 -49
- mindspore/ops/composite/multitype_ops/_compile_utils.py +192 -618
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +25 -134
- mindspore/ops/composite/multitype_ops/add_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +8 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +8 -2
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +9 -3
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +8 -3
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/pow_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +32 -21
- mindspore/ops/composite/multitype_ops/sub_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +6 -3
- mindspore/ops/deprecated.py +14 -3
- mindspore/ops/function/__init__.py +53 -11
- mindspore/ops/function/array_func.py +1269 -1821
- mindspore/ops/function/clip_func.py +19 -31
- mindspore/ops/function/debug_func.py +114 -5
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/grad_func.py +30 -22
- mindspore/ops/function/image_func.py +27 -21
- mindspore/ops/function/linalg_func.py +35 -68
- mindspore/ops/function/math_func.py +1170 -2697
- mindspore/ops/function/nn_func.py +2116 -1128
- mindspore/ops/function/other_func.py +8 -8
- mindspore/ops/function/parameter_func.py +5 -93
- mindspore/ops/function/random_func.py +435 -113
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +4 -4
- mindspore/ops/function/sparse_unary_func.py +9 -16
- mindspore/ops/function/spectral_func.py +1 -1
- mindspore/ops/function/vmap_func.py +16 -15
- mindspore/ops/functional.py +355 -346
- mindspore/ops/op_info_register.py +18 -45
- mindspore/ops/operations/__init__.py +38 -24
- mindspore/ops/operations/_grad_ops.py +21 -927
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +173 -607
- mindspore/ops/operations/_rl_inner_ops.py +2 -2
- mindspore/ops/operations/_scalar_ops.py +5 -480
- mindspore/ops/operations/_sequence_ops.py +6 -36
- mindspore/ops/operations/_tensor_array.py +8 -8
- mindspore/ops/operations/array_ops.py +106 -2837
- mindspore/ops/operations/comm_ops.py +799 -127
- mindspore/ops/operations/custom_ops.py +124 -119
- mindspore/ops/operations/debug_ops.py +142 -41
- mindspore/ops/operations/image_ops.py +1 -217
- mindspore/ops/operations/inner_ops.py +5 -40
- mindspore/ops/operations/linalg_ops.py +1 -49
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +666 -4972
- mindspore/ops/operations/nn_ops.py +205 -2213
- mindspore/ops/operations/other_ops.py +60 -49
- mindspore/ops/operations/random_ops.py +50 -54
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/sparse_ops.py +4 -4
- mindspore/ops/primitive.py +216 -103
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +8 -4
- mindspore/parallel/_auto_parallel_context.py +100 -10
- mindspore/parallel/_cell_wrapper.py +99 -9
- mindspore/parallel/_cost_model_context.py +1 -1
- mindspore/parallel/_dp_allreduce_fusion.py +159 -159
- mindspore/parallel/_parallel_serialization.py +67 -23
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +1 -1
- mindspore/parallel/_tensor.py +99 -22
- mindspore/parallel/_transformer/__init__.py +1 -1
- mindspore/parallel/_transformer/layers.py +1 -1
- mindspore/parallel/_transformer/loss.py +1 -1
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/op_parallel_config.py +1 -1
- mindspore/parallel/_transformer/transformer.py +2 -2
- mindspore/parallel/_utils.py +173 -6
- mindspore/parallel/algo_parameter_config.py +8 -10
- mindspore/parallel/checkpoint_transform.py +204 -38
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +1 -1
- mindspore/parallel/mpi/_mpi_config.py +1 -1
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +279 -37
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +4 -2
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/util.py +153 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +18 -20
- mindspore/{_extends/parallel_compile/tbe_compiler → profiler/parser/ascend_analysis}/__init__.py +1 -1
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +14 -9
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -1
- mindspore/profiler/parser/ascend_flops_generator.py +20 -4
- mindspore/profiler/parser/ascend_hccl_generator.py +29 -278
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +148 -146
- mindspore/profiler/parser/ascend_msprof_generator.py +73 -283
- mindspore/profiler/parser/ascend_op_generator.py +92 -42
- mindspore/profiler/parser/ascend_timeline_generator.py +298 -133
- mindspore/profiler/parser/base_timeline_generator.py +25 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +4 -393
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/integrator.py +3 -1
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/minddata_parser.py +72 -3
- mindspore/profiler/parser/profiler_info.py +94 -7
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +631 -508
- mindspore/rewrite/__init__.py +2 -14
- mindspore/rewrite/api/node.py +122 -36
- mindspore/rewrite/api/pattern_engine.py +2 -3
- mindspore/rewrite/api/scoped_value.py +16 -15
- mindspore/rewrite/api/symbol_tree.py +45 -29
- mindspore/rewrite/ast_helpers/__init__.py +3 -6
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +48 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +160 -92
- mindspore/rewrite/common/__init__.py +1 -2
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/{rewrite_elog.py → error_log.py} +39 -39
- mindspore/rewrite/{namer.py → common/namer.py} +63 -18
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/node/__init__.py +5 -5
- mindspore/rewrite/node/call_function.py +23 -7
- mindspore/rewrite/node/cell_container.py +7 -3
- mindspore/rewrite/node/control_flow.py +53 -28
- mindspore/rewrite/node/node.py +212 -196
- mindspore/rewrite/node/node_manager.py +51 -22
- mindspore/rewrite/node/node_topological_manager.py +3 -23
- mindspore/rewrite/parsers/__init__.py +12 -0
- mindspore/rewrite/parsers/arguments_parser.py +8 -9
- mindspore/rewrite/parsers/assign_parser.py +637 -413
- mindspore/rewrite/parsers/attribute_parser.py +3 -4
- mindspore/rewrite/parsers/class_def_parser.py +115 -148
- mindspore/rewrite/parsers/constant_parser.py +5 -5
- mindspore/rewrite/parsers/container_parser.py +4 -6
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +31 -98
- mindspore/rewrite/parsers/function_def_parser.py +13 -5
- mindspore/rewrite/parsers/if_parser.py +28 -10
- mindspore/rewrite/parsers/module_parser.py +8 -182
- mindspore/rewrite/parsers/parser.py +1 -5
- mindspore/rewrite/parsers/parser_register.py +1 -1
- mindspore/rewrite/parsers/return_parser.py +5 -10
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/{symbol_tree.py → symbol_tree/symbol_tree.py} +705 -186
- mindspore/rewrite/{symbol_tree_builder.py → symbol_tree/symbol_tree_builder.py} +8 -8
- mindspore/rewrite/{symbol_tree_dumper.py → symbol_tree/symbol_tree_dumper.py} +4 -4
- mindspore/run_check/_check_version.py +40 -115
- mindspore/run_check/run_check.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +597 -263
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +7 -5
- mindspore/train/_utils.py +204 -4
- mindspore/train/amp.py +335 -295
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/__init__.py +5 -2
- mindspore/train/callback/_backup_and_restore.py +5 -5
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +220 -43
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_landscape.py +15 -9
- mindspore/train/callback/_loss_monitor.py +5 -5
- mindspore/train/callback/_on_request_exit.py +136 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +12 -12
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +3 -3
- mindspore/train/data_sink.py +6 -5
- mindspore/train/dataset_helper.py +66 -23
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/accuracy.py +7 -7
- mindspore/train/metrics/confusion_matrix.py +8 -6
- mindspore/train/metrics/cosine_similarity.py +6 -4
- mindspore/train/metrics/error.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/perplexity.py +2 -1
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/metrics/topk.py +2 -2
- mindspore/train/mind_ir_pb2.py +116 -37
- mindspore/train/model.py +382 -76
- mindspore/train/serialization.py +787 -288
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/summary_record.py +51 -28
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/METADATA +8 -4
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +1 -0
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +0 -662
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +0 -377
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +0 -201
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +0 -515
- mindspore/gen_ops.py +0 -273
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -282
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/layer/flash_attention.py +0 -189
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/_op_impl/cpu/concat.py +0 -39
- mindspore/ops/_op_impl/cpu/tensor_shape.py +0 -42
- mindspore/ops/_op_impl/tbe/__init__.py +0 -47
- mindspore/ops/_op_impl/tbe/abs.py +0 -38
- mindspore/ops/_op_impl/tbe/abs_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/abs_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/abs_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/accumulate_n_v2_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/acos.py +0 -37
- mindspore/ops/_op_impl/tbe/acos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acos_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acos_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/acosh.py +0 -37
- mindspore/ops/_op_impl/tbe/acosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/acosh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/acosh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_max_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/act_ulq_clamp_min_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/acts_ulq.py +0 -45
- mindspore/ops/_op_impl/tbe/acts_ulq_input_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/adam_apply_one.py +0 -50
- mindspore/ops/_op_impl/tbe/adam_apply_one_assign.py +0 -53
- mindspore/ops/_op_impl/tbe/adam_apply_one_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_assign.py +0 -54
- mindspore/ops/_op_impl/tbe/adam_apply_one_with_decay_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/adaptive_max_pool2d.py +0 -37
- mindspore/ops/_op_impl/tbe/add.py +0 -42
- mindspore/ops/_op_impl/tbe/add_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/add_n.py +0 -39
- mindspore/ops/_op_impl/tbe/add_n_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/addcdiv.py +0 -41
- mindspore/ops/_op_impl/tbe/addcdiv_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/addcmul.py +0 -43
- mindspore/ops/_op_impl/tbe/addcmul_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_ada_max.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_ada_max_ds.py +0 -69
- mindspore/ops/_op_impl/tbe/apply_adadelta.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_adadelta_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_adagrad_d_a.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_adagrad_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/apply_adagrad_v2_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_adam.py +0 -79
- mindspore/ops/_op_impl/tbe/apply_adam_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad.py +0 -60
- mindspore/ops/_op_impl/tbe/apply_adam_with_amsgrad_ds.py +0 -61
- mindspore/ops/_op_impl/tbe/apply_add_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_add_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop.py +0 -77
- mindspore/ops/_op_impl/tbe/apply_centered_rms_prop_ds.py +0 -78
- mindspore/ops/_op_impl/tbe/apply_ftrl.py +0 -67
- mindspore/ops/_op_impl/tbe/apply_ftrl_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/apply_gradient_descent.py +0 -44
- mindspore/ops/_op_impl/tbe/apply_gradient_descent_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/apply_keras_momentum.py +0 -49
- mindspore/ops/_op_impl/tbe/apply_momentum.py +0 -64
- mindspore/ops/_op_impl/tbe/apply_momentum_ds.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign.py +0 -65
- mindspore/ops/_op_impl/tbe/apply_power_sign_ds.py +0 -66
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad.py +0 -57
- mindspore/ops/_op_impl/tbe/apply_proximal_adagrad_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent.py +0 -54
- mindspore/ops/_op_impl/tbe/apply_proximal_gradient_descent_ds.py +0 -55
- mindspore/ops/_op_impl/tbe/apply_rms_prop.py +0 -52
- mindspore/ops/_op_impl/tbe/approximate_equal.py +0 -39
- mindspore/ops/_op_impl/tbe/approximate_equal_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_max.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_max_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/arg_min.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_v2_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/arg_min_with_value.py +0 -38
- mindspore/ops/_op_impl/tbe/arg_min_with_value_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/asin.py +0 -37
- mindspore/ops/_op_impl/tbe/asin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asin_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asin_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/asinh.py +0 -37
- mindspore/ops/_op_impl/tbe/asinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/asinh_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/asinh_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/assign.py +0 -79
- mindspore/ops/_op_impl/tbe/assign_add.py +0 -59
- mindspore/ops/_op_impl/tbe/assign_add_ds.py +0 -60
- mindspore/ops/_op_impl/tbe/assign_ds.py +0 -80
- mindspore/ops/_op_impl/tbe/assign_sub.py +0 -55
- mindspore/ops/_op_impl/tbe/assign_sub_ds.py +0 -56
- mindspore/ops/_op_impl/tbe/atan.py +0 -37
- mindspore/ops/_op_impl/tbe/atan2.py +0 -38
- mindspore/ops/_op_impl/tbe/atan2_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/atan_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/atan_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/atan_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/atanh.py +0 -37
- mindspore/ops/_op_impl/tbe/atanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/avg_pool.py +0 -43
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +0 -45
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell.py +0 -57
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +0 -50
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad_v2.py +0 -51
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_input_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/basic_lstm_cell_weight_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul.py +0 -42
- mindspore/ops/_op_impl/tbe/batch_matmul_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/batch_matmul_v2.py +0 -47
- mindspore/ops/_op_impl/tbe/batch_to_space.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/batch_to_space_nd_v2.py +0 -41
- mindspore/ops/_op_impl/tbe/batchnorm.py +0 -58
- mindspore/ops/_op_impl/tbe/batchnorm_grad.py +0 -58
- mindspore/ops/_op_impl/tbe/bce_with_logits_loss.py +0 -42
- mindspore/ops/_op_impl/tbe/bessel_i0e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i0e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bessel_i1e.py +0 -37
- mindspore/ops/_op_impl/tbe/bessel_i1e_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add.py +0 -38
- mindspore/ops/_op_impl/tbe/bias_add_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bias_add_grad.py +0 -53
- mindspore/ops/_op_impl/tbe/binary_cross_entropy.py +0 -39
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/binary_cross_entropy_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bitwise_and.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_and_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_or.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_or_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bitwise_xor.py +0 -39
- mindspore/ops/_op_impl/tbe/bitwise_xor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_infer.py +0 -43
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_infer_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/bn_infer_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/bn_inference.py +0 -50
- mindspore/ops/_op_impl/tbe/bn_training_reduce.py +0 -38
- mindspore/ops/_op_impl/tbe/bn_training_reduce_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/bn_training_reduce_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/bn_training_update.py +0 -52
- mindspore/ops/_op_impl/tbe/bn_training_update_ds.py +0 -53
- mindspore/ops/_op_impl/tbe/bn_training_update_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/bn_training_update_grad_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/bn_training_update_v2.py +0 -48
- mindspore/ops/_op_impl/tbe/bn_training_update_v3.py +0 -51
- mindspore/ops/_op_impl/tbe/bounding_box_decode.py +0 -41
- mindspore/ops/_op_impl/tbe/bounding_box_decode_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/bounding_box_encode.py +0 -38
- mindspore/ops/_op_impl/tbe/broadcast_to.py +0 -40
- mindspore/ops/_op_impl/tbe/broadcast_to_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cast.py +0 -55
- mindspore/ops/_op_impl/tbe/cast_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/cdist.py +0 -38
- mindspore/ops/_op_impl/tbe/cdist_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/ceil.py +0 -37
- mindspore/ops/_op_impl/tbe/ceil_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/celu.py +0 -39
- mindspore/ops/_op_impl/tbe/centralization.py +0 -39
- mindspore/ops/_op_impl/tbe/check_valid.py +0 -38
- mindspore/ops/_op_impl/tbe/check_valid_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_norm_no_div_sum_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/clip_by_value.py +0 -41
- mindspore/ops/_op_impl/tbe/clip_by_value_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/concat.py +0 -40
- mindspore/ops/_op_impl/tbe/concat_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/confusion_matrix.py +0 -63
- mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/confusion_softmax_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/confusion_transpose_d.py +0 -39
- mindspore/ops/_op_impl/tbe/conv2d.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv2d_backprop_input_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/conv2d_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/conv2d_transpose.py +0 -48
- mindspore/ops/_op_impl/tbe/conv3d.py +0 -45
- mindspore/ops/_op_impl/tbe/conv3d_backprop_filter.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_backprop_input.py +0 -42
- mindspore/ops/_op_impl/tbe/conv3d_transpose.py +0 -47
- mindspore/ops/_op_impl/tbe/conv3d_transpose_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/cos.py +0 -37
- mindspore/ops/_op_impl/tbe/cos_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/cosh.py +0 -37
- mindspore/ops/_op_impl/tbe/cosh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/ctc_loss_v2.py +0 -42
- mindspore/ops/_op_impl/tbe/ctc_loss_v2_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/cum_sum.py +0 -42
- mindspore/ops/_op_impl/tbe/cum_sum_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/cummin.py +0 -41
- mindspore/ops/_op_impl/tbe/cumprod.py +0 -42
- mindspore/ops/_op_impl/tbe/data_format_dim_map.py +0 -38
- mindspore/ops/_op_impl/tbe/data_format_dim_map_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/deformable_offsets.py +0 -45
- mindspore/ops/_op_impl/tbe/deformable_offsets_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/depth_to_space_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +0 -44
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +0 -41
- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +0 -41
- mindspore/ops/_op_impl/tbe/diag.py +0 -38
- mindspore/ops/_op_impl/tbe/diag_part.py +0 -38
- mindspore/ops/_op_impl/tbe/dilation.py +0 -40
- mindspore/ops/_op_impl/tbe/div.py +0 -41
- mindspore/ops/_op_impl/tbe/div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/div_no_nan.py +0 -41
- mindspore/ops/_op_impl/tbe/div_no_nan_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/dropout_do_mask.py +0 -38
- mindspore/ops/_op_impl/tbe/dropout_do_mask_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/dynamic_atomic_addr_clean.py +0 -34
- mindspore/ops/_op_impl/tbe/dynamic_gru_v2.py +0 -95
- mindspore/ops/_op_impl/tbe/dynamic_rnn.py +0 -82
- mindspore/ops/_op_impl/tbe/elu.py +0 -38
- mindspore/ops/_op_impl/tbe/elu_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/elu_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/elu_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/equal.py +0 -42
- mindspore/ops/_op_impl/tbe/equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/erf.py +0 -37
- mindspore/ops/_op_impl/tbe/erf_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfc.py +0 -37
- mindspore/ops/_op_impl/tbe/erfc_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/erfinv.py +0 -36
- mindspore/ops/_op_impl/tbe/exp.py +0 -40
- mindspore/ops/_op_impl/tbe/exp_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/expand_dims.py +0 -38
- mindspore/ops/_op_impl/tbe/expm1.py +0 -37
- mindspore/ops/_op_impl/tbe/expm1_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/extract_image_patches.py +0 -41
- mindspore/ops/_op_impl/tbe/extract_volume_patches.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel.py +0 -39
- mindspore/ops/_op_impl/tbe/fake_quant_with_min_max_vars_per_channel_gradient.py +0 -43
- mindspore/ops/_op_impl/tbe/fast_gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/fast_gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/fast_gelu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/fast_gelu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/fill.py +0 -56
- mindspore/ops/_op_impl/tbe/fill_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/flatten.py +0 -48
- mindspore/ops/_op_impl/tbe/floor.py +0 -37
- mindspore/ops/_op_impl/tbe/floor_div.py +0 -41
- mindspore/ops/_op_impl/tbe/floor_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/floor_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/floor_mod.py +0 -39
- mindspore/ops/_op_impl/tbe/floor_mod_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/fused_dbn_dw.py +0 -52
- mindspore/ops/_op_impl/tbe/fused_mul_add.py +0 -38
- mindspore/ops/_op_impl/tbe/fused_mul_add_n.py +0 -48
- mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +0 -53
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum.py +0 -57
- mindspore/ops/_op_impl/tbe/fused_mul_apply_momentum_extern.py +0 -67
- mindspore/ops/_op_impl/tbe/gather_nd.py +0 -52
- mindspore/ops/_op_impl/tbe/gather_nd_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/gather_v2_ds.py +0 -68
- mindspore/ops/_op_impl/tbe/gelu.py +0 -37
- mindspore/ops/_op_impl/tbe/gelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/gelu_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/gelu_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/ger.py +0 -43
- mindspore/ops/_op_impl/tbe/ger_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/greater.py +0 -43
- mindspore/ops/_op_impl/tbe/greater_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/greater_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/gru_v2_hidden_grad_cell.py +0 -52
- mindspore/ops/_op_impl/tbe/hard_swish.py +0 -37
- mindspore/ops/_op_impl/tbe/hard_swish_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/hard_swish_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/hard_swish_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/histogram_fixed_width.py +0 -40
- mindspore/ops/_op_impl/tbe/hshrink.py +0 -33
- mindspore/ops/_op_impl/tbe/hshrink_grad.py +0 -37
- mindspore/ops/_op_impl/tbe/hsigmoid.py +0 -45
- mindspore/ops/_op_impl/tbe/hsigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/ifmr.py +0 -47
- mindspore/ops/_op_impl/tbe/ifmr_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/im2col.py +0 -42
- mindspore/ops/_op_impl/tbe/in_top_k.py +0 -37
- mindspore/ops/_op_impl/tbe/inplace_add.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +0 -46
- mindspore/ops/_op_impl/tbe/inplace_sub.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update.py +0 -39
- mindspore/ops/_op_impl/tbe/inplace_update_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/inv.py +0 -38
- mindspore/ops/_op_impl/tbe/inv_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/inv_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/inv_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/invert.py +0 -37
- mindspore/ops/_op_impl/tbe/invert_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/iou.py +0 -38
- mindspore/ops/_op_impl/tbe/iou_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/is_close.py +0 -40
- mindspore/ops/_op_impl/tbe/kl_div_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/kl_div_loss_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/kl_div_loss_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/l2_loss.py +0 -36
- mindspore/ops/_op_impl/tbe/l2_loss_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/l2_normalize.py +0 -38
- mindspore/ops/_op_impl/tbe/l2_normalize_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/lamb_apply_optimizer_assign.py +0 -55
- mindspore/ops/_op_impl/tbe/lamb_apply_weight_assign.py +0 -42
- mindspore/ops/_op_impl/tbe/lamb_next_mv.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py +0 -59
- mindspore/ops/_op_impl/tbe/lamb_next_right.py +0 -44
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr.py +0 -48
- mindspore/ops/_op_impl/tbe/lamb_update_with_lr_v2.py +0 -44
- mindspore/ops/_op_impl/tbe/lars_update.py +0 -50
- mindspore/ops/_op_impl/tbe/lars_update_ds.py +0 -51
- mindspore/ops/_op_impl/tbe/layer_norm.py +0 -46
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/layer_norm_beta_gamma_backprop_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/layer_norm_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/layer_norm_grad.py +0 -48
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop.py +0 -43
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2.py +0 -45
- mindspore/ops/_op_impl/tbe/layer_norm_x_backprop_v2_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/lerp.py +0 -38
- mindspore/ops/_op_impl/tbe/less.py +0 -41
- mindspore/ops/_op_impl/tbe/less_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/less_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/less_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/log.py +0 -40
- mindspore/ops/_op_impl/tbe/log1p.py +0 -37
- mindspore/ops/_op_impl/tbe/log1p_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/log_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/logical_and.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_and_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logical_not.py +0 -36
- mindspore/ops/_op_impl/tbe/logical_not_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or.py +0 -37
- mindspore/ops/_op_impl/tbe/logical_or_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax.py +0 -37
- mindspore/ops/_op_impl/tbe/logsoftmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/logsoftmax_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/lp_norm.py +0 -40
- mindspore/ops/_op_impl/tbe/lp_norm_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn.py +0 -41
- mindspore/ops/_op_impl/tbe/lrn_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/lstm_input_grad.py +0 -51
- mindspore/ops/_op_impl/tbe/masked_fill.py +0 -40
- mindspore/ops/_op_impl/tbe/masked_fill_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/matmul.py +0 -53
- mindspore/ops/_op_impl/tbe/matmul_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/matmul_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/matrix_diag.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_diag_part.py +0 -45
- mindspore/ops/_op_impl/tbe/matrix_set_diag.py +0 -46
- mindspore/ops/_op_impl/tbe/max_pool.py +0 -39
- mindspore/ops/_op_impl/tbe/max_pool3d.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool3d_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool3d_grad_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/max_pool_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/max_pool_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_grad_with_argmax.py +0 -41
- mindspore/ops/_op_impl/tbe/max_pool_grad_with_argmax.py +0 -42
- mindspore/ops/_op_impl/tbe/max_pool_with_argmax.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum.py +0 -39
- mindspore/ops/_op_impl/tbe/maximum_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/maximum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/maximum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mem_set.py +0 -38
- mindspore/ops/_op_impl/tbe/minimum.py +0 -40
- mindspore/ops/_op_impl/tbe/minimum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/minimum_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/minimum_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/mish.py +0 -37
- mindspore/ops/_op_impl/tbe/mod.py +0 -41
- mindspore/ops/_op_impl/tbe/mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/mul.py +0 -37
- mindspore/ops/_op_impl/tbe/mul_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/mul_no_nan.py +0 -39
- mindspore/ops/_op_impl/tbe/mul_no_nan_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/multilabel_margin_loss.py +0 -39
- mindspore/ops/_op_impl/tbe/neg.py +0 -39
- mindspore/ops/_op_impl/tbe/neg_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/new_im2col.py +0 -40
- mindspore/ops/_op_impl/tbe/nll_loss.py +0 -41
- mindspore/ops/_op_impl/tbe/nll_loss_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/nms_with_mask.py +0 -39
- mindspore/ops/_op_impl/tbe/not_equal.py +0 -41
- mindspore/ops/_op_impl/tbe/not_equal_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/npu_alloc_float_status.py +0 -34
- mindspore/ops/_op_impl/tbe/npu_clear_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_clear_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status.py +0 -35
- mindspore/ops/_op_impl/tbe/npu_get_float_status_v2.py +0 -35
- mindspore/ops/_op_impl/tbe/one_hot.py +0 -48
- mindspore/ops/_op_impl/tbe/one_hot_ds.py +0 -45
- mindspore/ops/_op_impl/tbe/ones_like.py +0 -40
- mindspore/ops/_op_impl/tbe/ones_like_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling.py +0 -40
- mindspore/ops/_op_impl/tbe/p_s_r_o_i_pooling_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/pack.py +0 -58
- mindspore/ops/_op_impl/tbe/pack_ds.py +0 -59
- mindspore/ops/_op_impl/tbe/pad_d.py +0 -40
- mindspore/ops/_op_impl/tbe/pad_d_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/parallel_concat.py +0 -70
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear.py +0 -45
- mindspore/ops/_op_impl/tbe/parallel_resize_bilinear_grad.py +0 -44
- mindspore/ops/_op_impl/tbe/pdist.py +0 -36
- mindspore/ops/_op_impl/tbe/pooling.py +0 -46
- mindspore/ops/_op_impl/tbe/population_count.py +0 -38
- mindspore/ops/_op_impl/tbe/pow.py +0 -41
- mindspore/ops/_op_impl/tbe/pow_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/prelu.py +0 -37
- mindspore/ops/_op_impl/tbe/prelu_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/prelu_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/range.py +0 -39
- mindspore/ops/_op_impl/tbe/real_div.py +0 -38
- mindspore/ops/_op_impl/tbe/real_div_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reciprocal.py +0 -36
- mindspore/ops/_op_impl/tbe/reciprocal_ds.py +0 -37
- mindspore/ops/_op_impl/tbe/reciprocal_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/reciprocal_grad_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_all.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_all_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_any.py +0 -38
- mindspore/ops/_op_impl/tbe/reduce_any_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_max.py +0 -43
- mindspore/ops/_op_impl/tbe/reduce_max_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_mean.py +0 -40
- mindspore/ops/_op_impl/tbe/reduce_mean_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_min.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_min_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_prod.py +0 -42
- mindspore/ops/_op_impl/tbe/reduce_prod_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/reduce_std.py +0 -44
- mindspore/ops/_op_impl/tbe/reduce_sum.py +0 -39
- mindspore/ops/_op_impl/tbe/reduce_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6.py +0 -38
- mindspore/ops/_op_impl/tbe/relu6_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/relu6_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/relu6_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/relu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/relu_grad_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_grad_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/relu_v2.py +0 -40
- mindspore/ops/_op_impl/tbe/relu_v2_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/renorm.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_bilinear.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_bilinear_grad.py +0 -41
- mindspore/ops/_op_impl/tbe/resize_bilinear_v2.py +0 -43
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/reverse_v2_d.py +0 -37
- mindspore/ops/_op_impl/tbe/rint.py +0 -37
- mindspore/ops/_op_impl/tbe/rint_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/roi_align.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roi_align_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/roi_align_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/roll.py +0 -42
- mindspore/ops/_op_impl/tbe/round.py +0 -38
- mindspore/ops/_op_impl/tbe/round_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/rsqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/rsqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/rsqrt_grad.py +0 -40
- mindspore/ops/_op_impl/tbe/rsqrt_grad_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_add.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_div.py +0 -46
- mindspore/ops/_op_impl/tbe/scatter_max.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_min.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_mul.py +0 -44
- mindspore/ops/_op_impl/tbe/scatter_nd.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +0 -45
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +0 -41
- mindspore/ops/_op_impl/tbe/scatter_nd_ds.py +0 -49
- mindspore/ops/_op_impl/tbe/scatter_nd_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_nd_update.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_nd_update_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add.py +0 -39
- mindspore/ops/_op_impl/tbe/scatter_non_aliasing_add_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/scatter_sub.py +0 -47
- mindspore/ops/_op_impl/tbe/scatter_sub_ds.py +0 -48
- mindspore/ops/_op_impl/tbe/scatter_update.py +0 -43
- mindspore/ops/_op_impl/tbe/select.py +0 -38
- mindspore/ops/_op_impl/tbe/select_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/selu.py +0 -39
- mindspore/ops/_op_impl/tbe/selu_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sgd.py +0 -62
- mindspore/ops/_op_impl/tbe/sigmoid.py +0 -37
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad.py +0 -42
- mindspore/ops/_op_impl/tbe/sigmoid_cross_entropy_with_logits_grad_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/sigmoid_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sigmoid_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/sigmoid_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/sign.py +0 -38
- mindspore/ops/_op_impl/tbe/sign_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/sin.py +0 -37
- mindspore/ops/_op_impl/tbe/sin_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sinh.py +0 -37
- mindspore/ops/_op_impl/tbe/sinh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/slice.py +0 -58
- mindspore/ops/_op_impl/tbe/smooth_l1_loss.py +0 -45
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad.py +0 -46
- mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/soft_margin_loss.py +0 -38
- mindspore/ops/_op_impl/tbe/soft_margin_loss_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/soft_shrink.py +0 -36
- mindspore/ops/_op_impl/tbe/soft_shrink_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax.py +0 -37
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_cross_entropy_with_logits_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/softmax_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softmax_grad_ext.py +0 -42
- mindspore/ops/_op_impl/tbe/softmax_v2_with_dropout_do_mask_v3.py +0 -39
- mindspore/ops/_op_impl/tbe/softplus.py +0 -37
- mindspore/ops/_op_impl/tbe/softplus_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad.py +0 -38
- mindspore/ops/_op_impl/tbe/softplus_grad_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/softsign.py +0 -37
- mindspore/ops/_op_impl/tbe/softsign_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sort.py +0 -38
- mindspore/ops/_op_impl/tbe/sort_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/space_to_batch.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +0 -38
- mindspore/ops/_op_impl/tbe/space_to_depth.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_adadelta.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad.py +0 -45
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_ds.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py +0 -46
- mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2_ds.py +0 -47
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py +0 -53
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py +0 -50
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +0 -66
- mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad_ds.py +0 -67
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop.py +0 -57
- mindspore/ops/_op_impl/tbe/sparse_apply_r_m_s_prop_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/sparse_gather_v2.py +0 -56
- mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py +0 -58
- mindspore/ops/_op_impl/tbe/split_d.py +0 -38
- mindspore/ops/_op_impl/tbe/split_d_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/split_v.py +0 -39
- mindspore/ops/_op_impl/tbe/splitv.py +0 -39
- mindspore/ops/_op_impl/tbe/sqrt.py +0 -37
- mindspore/ops/_op_impl/tbe/sqrt_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/sqrt_grad.py +0 -43
- mindspore/ops/_op_impl/tbe/sqrt_grad_ds.py +0 -44
- mindspore/ops/_op_impl/tbe/square.py +0 -38
- mindspore/ops/_op_impl/tbe/square_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_all.py +0 -40
- mindspore/ops/_op_impl/tbe/square_sum_all_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/square_sum_v1.py +0 -38
- mindspore/ops/_op_impl/tbe/square_sum_v1_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/square_sum_v2.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference.py +0 -39
- mindspore/ops/_op_impl/tbe/squared_difference_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/squeeze.py +0 -37
- mindspore/ops/_op_impl/tbe/strided_read.py +0 -38
- mindspore/ops/_op_impl/tbe/strided_slice_d.py +0 -44
- mindspore/ops/_op_impl/tbe/strided_slice_ds.py +0 -71
- mindspore/ops/_op_impl/tbe/strided_slice_grad_d.py +0 -51
- mindspore/ops/_op_impl/tbe/strided_slice_grad_ds.py +0 -57
- mindspore/ops/_op_impl/tbe/strided_write.py +0 -38
- mindspore/ops/_op_impl/tbe/sub.py +0 -39
- mindspore/ops/_op_impl/tbe/sub_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tan.py +0 -38
- mindspore/ops/_op_impl/tbe/tan_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh.py +0 -37
- mindspore/ops/_op_impl/tbe/tanh_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/tanh_grad.py +0 -39
- mindspore/ops/_op_impl/tbe/tanh_grad_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/tensor_move.py +0 -49
- mindspore/ops/_op_impl/tbe/tensor_move_ds.py +0 -50
- mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +0 -41
- mindspore/ops/_op_impl/tbe/tile.py +0 -37
- mindspore/ops/_op_impl/tbe/tile_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k.py +0 -42
- mindspore/ops/_op_impl/tbe/top_k_ds.py +0 -43
- mindspore/ops/_op_impl/tbe/trans_data.py +0 -167
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +0 -180
- mindspore/ops/_op_impl/tbe/trans_data_rnn.py +0 -44
- mindspore/ops/_op_impl/tbe/transpose.py +0 -60
- mindspore/ops/_op_impl/tbe/transpose_d.py +0 -47
- mindspore/ops/_op_impl/tbe/transpose_nod.py +0 -60
- mindspore/ops/_op_impl/tbe/trunc.py +0 -39
- mindspore/ops/_op_impl/tbe/truncate_div.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_div_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/truncate_mod.py +0 -41
- mindspore/ops/_op_impl/tbe/truncate_mod_ds.py +0 -42
- mindspore/ops/_op_impl/tbe/unpack.py +0 -38
- mindspore/ops/_op_impl/tbe/unpack_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/unsorted_segment_max.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_max_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_min.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_min_ds.py +0 -40
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod.py +0 -49
- mindspore/ops/_op_impl/tbe/unsorted_segment_prod_ds.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum.py +0 -38
- mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py +0 -41
- mindspore/ops/_op_impl/tbe/wts_arq.py +0 -40
- mindspore/ops/_op_impl/tbe/xdivy.py +0 -38
- mindspore/ops/_op_impl/tbe/xdivy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/xlogy.py +0 -38
- mindspore/ops/_op_impl/tbe/xlogy_ds.py +0 -39
- mindspore/ops/_op_impl/tbe/zeros_like.py +0 -41
- mindspore/ops/_op_impl/tbe/zeros_like_ds.py +0 -42
- mindspore/ops/_tracefunc.py +0 -241
- mindspore/ops/arg_dtype_cast.py +0 -54
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/rewrite/api/tree_node_helper.py +0 -60
- mindspore/rewrite/ast_helpers/ast_creator.py +0 -115
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +0 -267
- mindspore/rewrite/ast_transformers/remove_return_out_of_if.py +0 -228
- mindspore/rewrite/namespace.py +0 -53
- mindspore-2.2.14.dist-info/RECORD +0 -1924
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.2.14.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
|
@@ -17,6 +17,7 @@
|
|
|
17
17
|
from types import FunctionType, MethodType
|
|
18
18
|
from collections.abc import Iterable
|
|
19
19
|
import os
|
|
20
|
+
import weakref
|
|
20
21
|
import numpy as np
|
|
21
22
|
|
|
22
23
|
from mindspore.common import Tensor
|
|
@@ -24,20 +25,21 @@ from mindspore.common._stub_tensor import StubTensor
|
|
|
24
25
|
from mindspore.ops import composite as C
|
|
25
26
|
from mindspore.ops.operations.array_ops import Cast
|
|
26
27
|
from mindspore.ops.operations._scalar_ops import bit_or, bit_and
|
|
27
|
-
from mindspore.ops.operations.comm_ops import ReduceOp
|
|
28
28
|
from mindspore.ops import signature as sig
|
|
29
29
|
from mindspore.ops.operations.math_ops import _infer_shape_reduce
|
|
30
30
|
from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
|
|
31
31
|
_run_op, _check_contains_variable
|
|
32
32
|
from mindspore._c_expression import Tensor as Tensor_
|
|
33
|
-
from mindspore._c_expression import typing
|
|
33
|
+
from mindspore._c_expression import typing, HookType
|
|
34
34
|
from mindspore import _checkparam as validator
|
|
35
35
|
from mindspore.common import dtype as mstype
|
|
36
36
|
from mindspore.common.parameter import Parameter
|
|
37
|
-
from mindspore.communication.management import GlobalComm, get_rank
|
|
37
|
+
from mindspore.communication.management import GlobalComm, get_rank, _get_group, get_group_size
|
|
38
38
|
from mindspore.common.api import _pynative_executor
|
|
39
39
|
from mindspore.common._register_for_adapter import ms_adapter_registry
|
|
40
40
|
from mindspore import ops
|
|
41
|
+
from ..auto_generate import TensorCopySlices, SiLU, Cummin, TopKRouter, ExtractImagePatches, DecoderKVCache, \
|
|
42
|
+
PromptKVCache, ApplyCamePart1, ApplyCamePart2, ApplyCamePart3, ApplyCamePart4
|
|
41
43
|
|
|
42
44
|
# Bit operation
|
|
43
45
|
bit_and = bit_and()
|
|
@@ -56,73 +58,28 @@ string_mul = Primitive("string_mul")
|
|
|
56
58
|
string_getitem = Primitive("string_getitem")
|
|
57
59
|
|
|
58
60
|
|
|
59
|
-
class
|
|
61
|
+
class Generator(Primitive):
|
|
60
62
|
r"""
|
|
61
|
-
|
|
62
|
-
The input tensor must be a 4-D tensor and the data format is NCHW.
|
|
63
|
-
|
|
64
|
-
Args:
|
|
65
|
-
ksizes (Union[tuple[int], list[int]]): The size of sliding window, must be a tuple or a list of integers,
|
|
66
|
-
and the format is [1, 1, ksize_row, ksize_col].
|
|
67
|
-
strides (Union[tuple[int], list[int]]): Distance between the centers of the two consecutive patches,
|
|
68
|
-
must be a tuple or list of int, and the format is [1, 1, stride_row, stride_col].
|
|
69
|
-
rates (Union[tuple[int], list[int]]): In each extracted patch, the gap between the corresponding dimension
|
|
70
|
-
pixel positions, must be a tuple or a list of integers, and the format is [1, 1, rate_row, rate_col].
|
|
71
|
-
padding (str): The type of padding algorithm, is a string whose value is "same" or "valid",
|
|
72
|
-
not case sensitive. Default: "valid".
|
|
73
|
-
|
|
74
|
-
- same: Means that the patch can take the part beyond the original image, and this part is filled with 0.
|
|
75
|
-
|
|
76
|
-
- valid: Means that the taken patch area must be completely covered in the original image.
|
|
63
|
+
Manage the state of random number generation.
|
|
77
64
|
|
|
78
65
|
Inputs:
|
|
79
|
-
- **
|
|
66
|
+
- **cmd** (int) : operation to be executed.
|
|
67
|
+
- **inputs** (tuple[tensor]) : inputs for the operation.
|
|
80
68
|
|
|
81
69
|
Outputs:
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
.. math::
|
|
87
|
-
out_depth=ksize\_row * ksize\_col * in\_depth
|
|
88
|
-
|
|
89
|
-
and
|
|
90
|
-
if 'padding' is "valid":
|
|
91
|
-
|
|
92
|
-
.. math::
|
|
93
|
-
out\_row=floor((in\_row - (ksize\_row + (ksize\_row - 1) * (rate\_row - 1))) / stride\_row) + 1
|
|
94
|
-
out\_col=floor((in\_col - (ksize\_col + (ksize\_col - 1) * (rate\_col - 1))) / stride\_col) + 1
|
|
95
|
-
|
|
96
|
-
if 'padding' is "same":
|
|
97
|
-
|
|
98
|
-
.. math::
|
|
99
|
-
out\_row=floor((in\_row - 1) / stride\_row) + 1
|
|
100
|
-
out\_col=floor((in\_col - 1) / stride\_col) + 1
|
|
101
|
-
|
|
102
|
-
Supported Platforms:
|
|
103
|
-
``Ascend`` ``GPU``
|
|
70
|
+
- **seed** (Tensor): Seed for the random number generation algorithm.
|
|
71
|
+
- **offset** (Tensor): Offset of the random number sequence.
|
|
72
|
+
- **state** (Tensor): State tensor, can be used to restore current state.
|
|
104
73
|
"""
|
|
105
74
|
|
|
106
75
|
@prim_attr_register
|
|
107
|
-
def __init__(self
|
|
108
|
-
""
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
f"{arg_name}_col, 1], but got {arg_val}.")
|
|
115
|
-
if not isinstance(arg_val[2], int) or not isinstance(arg_val[3], int) or arg_val[2] < 1 or arg_val[3] < 1:
|
|
116
|
-
raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s must be "
|
|
117
|
-
f"an positive integer number, but got {arg_name}_row is {arg_val[2]}, "
|
|
118
|
-
f"{arg_name}_col is {arg_val[3]}")
|
|
119
|
-
|
|
120
|
-
_check_tuple_or_list("ksize", ksizes, self.name)
|
|
121
|
-
_check_tuple_or_list("stride", strides, self.name)
|
|
122
|
-
_check_tuple_or_list("rate", rates, self.name)
|
|
123
|
-
validator.check_value_type('padding', padding, [str], self.name)
|
|
124
|
-
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
|
|
125
|
-
self.add_prim_attr("padding", self.padding)
|
|
76
|
+
def __init__(self):
|
|
77
|
+
self.add_prim_attr("side_effect_mem", True)
|
|
78
|
+
|
|
79
|
+
def __call__(self, cmd, inputs):
|
|
80
|
+
if cmd == 0: # step cmd
|
|
81
|
+
return inputs[0], inputs[1]
|
|
82
|
+
return super().__call__(cmd, inputs)
|
|
126
83
|
|
|
127
84
|
|
|
128
85
|
class Quant(PrimitiveWithInfer):
|
|
@@ -140,7 +97,7 @@ class Quant(PrimitiveWithInfer):
|
|
|
140
97
|
y = round(scale * x * scale + offset)
|
|
141
98
|
|
|
142
99
|
Note:
|
|
143
|
-
This operation only support
|
|
100
|
+
This operation only support Atlas 200/300/500 inference product.
|
|
144
101
|
|
|
145
102
|
Args:
|
|
146
103
|
scale (float) : Specifies the scaling ratio.
|
|
@@ -253,7 +210,7 @@ class Dequant(PrimitiveWithInfer):
|
|
|
253
210
|
y = x * deq\_scale * deq\_scale
|
|
254
211
|
|
|
255
212
|
Note:
|
|
256
|
-
This operation only support
|
|
213
|
+
This operation only support Atlas 200/300/500 inference product.
|
|
257
214
|
|
|
258
215
|
Args:
|
|
259
216
|
sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
|
|
@@ -274,10 +231,10 @@ class Dequant(PrimitiveWithInfer):
|
|
|
274
231
|
"""
|
|
275
232
|
|
|
276
233
|
@prim_attr_register
|
|
277
|
-
def __init__(self, sqrt_mode=False, relu_flag=False):
|
|
234
|
+
def __init__(self, sqrt_mode=False, relu_flag=False, dtype=mstype.float16):
|
|
278
235
|
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
|
|
279
236
|
self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
|
|
280
|
-
self.
|
|
237
|
+
self.dtype = dtype
|
|
281
238
|
|
|
282
239
|
def infer_shape(self, x_shape, deq_scale_shape):
|
|
283
240
|
return x_shape
|
|
@@ -289,6 +246,53 @@ class Dequant(PrimitiveWithInfer):
|
|
|
289
246
|
return mstype.float16
|
|
290
247
|
|
|
291
248
|
|
|
249
|
+
class AntiQuant(Primitive):
|
|
250
|
+
r"""
|
|
251
|
+
Returns the antiquantized value of input_x.
|
|
252
|
+
|
|
253
|
+
If `sqrt_mode` is False:
|
|
254
|
+
|
|
255
|
+
.. math::
|
|
256
|
+
y = scale * (x + offset)
|
|
257
|
+
|
|
258
|
+
If `sqrt_mode` is True:
|
|
259
|
+
|
|
260
|
+
.. math::
|
|
261
|
+
y = scale * scale * (x + offset)
|
|
262
|
+
|
|
263
|
+
Note:
|
|
264
|
+
This operation only support Atlas 200/300/500 inference product.
|
|
265
|
+
|
|
266
|
+
Args:
|
|
267
|
+
scale (float) : Specifies the scaling ratio.
|
|
268
|
+
offset (float): Specifies the offset.
|
|
269
|
+
sqrt_mode (bool) : Specifies whether to perform square root on `scale`. Default: ``False``.
|
|
270
|
+
|
|
271
|
+
Inputs:
|
|
272
|
+
- **input_x** (Tensor) : Input tensor. Must be mindspore.int8.
|
|
273
|
+
|
|
274
|
+
Outputs:
|
|
275
|
+
- Tensor: The antiquantized output tensor of type mindspore.float32.
|
|
276
|
+
|
|
277
|
+
Examples:
|
|
278
|
+
>>> from mindspore.ops.operations._inner_ops import AntiQuant
|
|
279
|
+
>>> input_x = Tensor([50.0, 20.0], mstype.int8)
|
|
280
|
+
>>> antiquant = AntiQuant(2.0, 1.0, False)
|
|
281
|
+
>>> y = antiquant(input_x)
|
|
282
|
+
>>> print(y)
|
|
283
|
+
[102. 42.]
|
|
284
|
+
"""
|
|
285
|
+
|
|
286
|
+
@prim_attr_register
|
|
287
|
+
def __init__(self, sqrt_mode=False, dtype=mstype.float16):
|
|
288
|
+
super().__init__("AntiQuant")
|
|
289
|
+
self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
|
|
290
|
+
self.dtype = dtype
|
|
291
|
+
|
|
292
|
+
self.init_prim_io_names(inputs=['x', 'scale', 'offset'],
|
|
293
|
+
outputs=['y'])
|
|
294
|
+
|
|
295
|
+
|
|
292
296
|
class MatrixDiag(PrimitiveWithInfer):
|
|
293
297
|
"""
|
|
294
298
|
Returns a batched diagonal tensor with a given batched diagonal values.
|
|
@@ -386,227 +390,6 @@ class MatrixDiagPart(PrimitiveWithInfer):
|
|
|
386
390
|
return out_shape
|
|
387
391
|
|
|
388
392
|
|
|
389
|
-
class Send(PrimitiveWithInfer):
|
|
390
|
-
"""
|
|
391
|
-
Send tensors from src_rank to the specified dest_rank.
|
|
392
|
-
|
|
393
|
-
Note:
|
|
394
|
-
Send and Receive must be used in combination and have same sr_tag.
|
|
395
|
-
Send must be used between servers.
|
|
396
|
-
|
|
397
|
-
Args:
|
|
398
|
-
sr_tag (int): A required integer identifying the send/recv message tag. The message will
|
|
399
|
-
will be received by the Receive op with the same "sr_tag".
|
|
400
|
-
dest_rank (int): A required integer identifying the destination rank.
|
|
401
|
-
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
|
|
402
|
-
|
|
403
|
-
Inputs:
|
|
404
|
-
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
405
|
-
|
|
406
|
-
Examples:
|
|
407
|
-
>>> import mindspore.ops as ops
|
|
408
|
-
>>> import mindspore.nn as nn
|
|
409
|
-
>>> from mindspore.communication import init
|
|
410
|
-
>>> from mindspore import Tensor
|
|
411
|
-
>>> import numpy as np
|
|
412
|
-
>>>
|
|
413
|
-
>>> init()
|
|
414
|
-
>>> class Net(nn.Cell):
|
|
415
|
-
>>> def __init__(self):
|
|
416
|
-
>>> super(Net, self).__init__()
|
|
417
|
-
>>> self.depend = ops.Depend()
|
|
418
|
-
>>> self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
|
|
419
|
-
>>>
|
|
420
|
-
>>> def construct(self, x):
|
|
421
|
-
>>> out = self.depend(x, self.send(x))
|
|
422
|
-
>>> return out
|
|
423
|
-
>>>
|
|
424
|
-
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
|
|
425
|
-
>>> net = Net()
|
|
426
|
-
>>> output = net(input_)
|
|
427
|
-
"""
|
|
428
|
-
|
|
429
|
-
@prim_attr_register
|
|
430
|
-
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP):
|
|
431
|
-
self.rank = dest_rank
|
|
432
|
-
self.sr_tag = sr_tag
|
|
433
|
-
self.group = group
|
|
434
|
-
self.add_prim_attr("no_eliminate", True)
|
|
435
|
-
|
|
436
|
-
def infer_shape(self, x_shape):
|
|
437
|
-
self.add_prim_attr("shape", x_shape)
|
|
438
|
-
return x_shape
|
|
439
|
-
|
|
440
|
-
def infer_dtype(self, x_dtype):
|
|
441
|
-
return x_dtype
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
class Receive(PrimitiveWithInfer):
|
|
445
|
-
"""
|
|
446
|
-
Receive tensors from src_rank.
|
|
447
|
-
|
|
448
|
-
Note:
|
|
449
|
-
Send and Receive must be used in combination and have same sr_tag.
|
|
450
|
-
Receive must be used between servers.
|
|
451
|
-
|
|
452
|
-
Args:
|
|
453
|
-
sr_tag (int): A required integer identifying the send/recv message tag. The message will
|
|
454
|
-
will be send by the Send op with the same "sr_tag".
|
|
455
|
-
src_rank (int): A required integer identifying the source rank.
|
|
456
|
-
shape (list[int]): A required list identifying the shape of the tensor to be received.
|
|
457
|
-
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
|
|
458
|
-
int8, int16, int32, float16, float32.
|
|
459
|
-
group (str, optional): The communication group to work on.
|
|
460
|
-
Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
|
|
461
|
-
|
|
462
|
-
Inputs:
|
|
463
|
-
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
464
|
-
|
|
465
|
-
Examples:
|
|
466
|
-
>>> import mindspore.ops as ops
|
|
467
|
-
>>> import mindspore.nn as nn
|
|
468
|
-
>>> from mindspore.communication import init
|
|
469
|
-
>>> from mindspore import Tensor
|
|
470
|
-
>>> import numpy as np
|
|
471
|
-
>>>
|
|
472
|
-
>>> init()
|
|
473
|
-
>>> class Net(nn.Cell):
|
|
474
|
-
>>> def __init__(self):
|
|
475
|
-
>>> super(Net, self).__init__()
|
|
476
|
-
>>> self.recv = ops.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
|
|
477
|
-
>>> group="hccl_world_group")
|
|
478
|
-
>>>
|
|
479
|
-
>>> def construct(self):
|
|
480
|
-
>>> out = self.recv()
|
|
481
|
-
>>> return out
|
|
482
|
-
>>>
|
|
483
|
-
>>> net = Net()
|
|
484
|
-
>>> output = net()
|
|
485
|
-
"""
|
|
486
|
-
|
|
487
|
-
@prim_attr_register
|
|
488
|
-
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP,
|
|
489
|
-
group_back=GlobalComm.WORLD_COMM_GROUP):
|
|
490
|
-
self.rank = src_rank
|
|
491
|
-
self.tag = sr_tag
|
|
492
|
-
self.shape = shape
|
|
493
|
-
self.dtype = dtype
|
|
494
|
-
self.group = group
|
|
495
|
-
self.add_prim_attr("no_eliminate", True)
|
|
496
|
-
valid_type = [mstype.float16, mstype.bfloat16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
|
|
497
|
-
args = {"dtype": dtype}
|
|
498
|
-
validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
|
|
499
|
-
|
|
500
|
-
def infer_shape(self, x_shape=None):
|
|
501
|
-
return self.get_attr_dict()['shape']
|
|
502
|
-
|
|
503
|
-
def infer_dtype(self, x_dtype=None):
|
|
504
|
-
return self.get_attr_dict()['dtype']
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
class Reduce(PrimitiveWithInfer):
|
|
508
|
-
"""
|
|
509
|
-
Reduces tensor across the processes in the specified communication group.
|
|
510
|
-
|
|
511
|
-
Note:
|
|
512
|
-
Only process with destination rank receives the reduced output.
|
|
513
|
-
Other processes only get a tensor with shape [1], which has no mathematical meaning.
|
|
514
|
-
|
|
515
|
-
Args:
|
|
516
|
-
dest_rank (int): Specifies the rank of the process that receives the reduced output.
|
|
517
|
-
op (str, optional): Specifies an operation used for element-wise reductions, like sum, prod, max, and min.
|
|
518
|
-
On the CPU, only 'sum' is supported. Default: ``ReduceOp.SUM`` .
|
|
519
|
-
group (str, optional): The communication group to work on.
|
|
520
|
-
Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
|
|
521
|
-
|
|
522
|
-
Inputs:
|
|
523
|
-
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
|
524
|
-
|
|
525
|
-
Examples:
|
|
526
|
-
>>> import mindspore.ops as ops
|
|
527
|
-
>>> import mindspore.nn as nn
|
|
528
|
-
>>> from mindspore.communication import init
|
|
529
|
-
>>> from mindspore import Tensor
|
|
530
|
-
>>> import numpy as np
|
|
531
|
-
>>> # Launch 4 processes.
|
|
532
|
-
>>> init()
|
|
533
|
-
>>> class ReduceNet(nn.Cell):
|
|
534
|
-
>>> def __init__(self):
|
|
535
|
-
>>> super(Net, self).__init__()
|
|
536
|
-
>>> self.reduce = ops.Reduce(dest_rank=1)
|
|
537
|
-
>>>
|
|
538
|
-
>>> def construct(self, x):
|
|
539
|
-
>>> out = self.reduce(x)
|
|
540
|
-
>>> return out
|
|
541
|
-
>>> input = Tensor(np.ones([2, 8]).astype(np.float32))
|
|
542
|
-
>>> net = ReduceNet()
|
|
543
|
-
>>> output = net(input)
|
|
544
|
-
>>> print(output)
|
|
545
|
-
Process with rank 1: [[4. 4. 4. 4. 4. 4. 4. 4.]
|
|
546
|
-
[4. 4. 4. 4. 4. 4. 4. 4.]],
|
|
547
|
-
Other proesses: [0.].
|
|
548
|
-
"""
|
|
549
|
-
|
|
550
|
-
@prim_attr_register
|
|
551
|
-
def __init__(self, dest_rank, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
|
552
|
-
self.dest_rank = dest_rank
|
|
553
|
-
self.op = op
|
|
554
|
-
self.group = group
|
|
555
|
-
|
|
556
|
-
def infer_shape(self, x_shape):
|
|
557
|
-
# The process with dest_rank returns the reduced output.
|
|
558
|
-
# Other processes only gets a tensor with shape [1], which has no mathematical meaning.
|
|
559
|
-
if self.dest_rank == get_rank():
|
|
560
|
-
return x_shape
|
|
561
|
-
return [1]
|
|
562
|
-
|
|
563
|
-
def infer_dtype(self, x_dtype):
|
|
564
|
-
return x_dtype
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
class Barrier(PrimitiveWithInfer):
|
|
568
|
-
"""
|
|
569
|
-
Synchronizes all processes in the specified group.
|
|
570
|
-
|
|
571
|
-
Note:
|
|
572
|
-
After calling this collective operator,
|
|
573
|
-
this process will be blocked until all other processes in the group call this operator.
|
|
574
|
-
|
|
575
|
-
Args:
|
|
576
|
-
group (str, optional): The communication group to work on.
|
|
577
|
-
Default: "hccl_world_group" on Ascend, "nccl_world_group" on GPU.
|
|
578
|
-
|
|
579
|
-
Examples:
|
|
580
|
-
>>> import mindspore.ops as ops
|
|
581
|
-
>>> import mindspore.nn as nn
|
|
582
|
-
>>> from mindspore.communication import init
|
|
583
|
-
>>> from mindspore import Tensor
|
|
584
|
-
>>> import numpy as np
|
|
585
|
-
>>> # Launch 4 processes.
|
|
586
|
-
>>> init()
|
|
587
|
-
>>> class BarrierNet(nn.Cell):
|
|
588
|
-
>>> def __init__(self):
|
|
589
|
-
>>> super(Net, self).__init__()
|
|
590
|
-
>>> self.barrier = ops.Barrier()
|
|
591
|
-
>>>
|
|
592
|
-
>>> def construct(self):
|
|
593
|
-
>>> self.barrier()
|
|
594
|
-
>>> net = BarrierNet()
|
|
595
|
-
>>> net()
|
|
596
|
-
"""
|
|
597
|
-
|
|
598
|
-
@prim_attr_register
|
|
599
|
-
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
|
|
600
|
-
self.group = group
|
|
601
|
-
self.add_prim_attr("side_effect_mem", True)
|
|
602
|
-
|
|
603
|
-
def infer_shape(self):
|
|
604
|
-
return [1]
|
|
605
|
-
|
|
606
|
-
def infer_dtype(self):
|
|
607
|
-
return mstype.float32
|
|
608
|
-
|
|
609
|
-
|
|
610
393
|
class MatrixSetDiag(PrimitiveWithInfer):
|
|
611
394
|
r"""
|
|
612
395
|
Modifies the batched diagonal part of a batched tensor.
|
|
@@ -1329,45 +1112,6 @@ class DynamicBroadcastGradientArgs(Primitive):
|
|
|
1329
1112
|
"""Init BroadcastGradientArgs"""
|
|
1330
1113
|
|
|
1331
1114
|
|
|
1332
|
-
class TensorCopySlices(Primitive):
|
|
1333
|
-
"""
|
|
1334
|
-
Copy continues memory.
|
|
1335
|
-
|
|
1336
|
-
Inputs:
|
|
1337
|
-
- **x** (Tensor) - The target Tensor.
|
|
1338
|
-
- **value** (Tensor) - The tensor to update x.
|
|
1339
|
-
- **begin** (tuple[int]) - A tuple which represents the location where to start. Only
|
|
1340
|
-
constant value is allowed.
|
|
1341
|
-
- **end** (tuple[int]) - A tuple or which represents the maximum location where to end.
|
|
1342
|
-
Only constant value is allowed.
|
|
1343
|
-
- **strides** (tuple[int]) - A tuple which represents the stride is continuously added
|
|
1344
|
-
before reaching the maximum location. Only constant value is allowed.
|
|
1345
|
-
|
|
1346
|
-
Outputs:
|
|
1347
|
-
- **y** (Tensor), has the same shape and data type of x.
|
|
1348
|
-
|
|
1349
|
-
Examples:
|
|
1350
|
-
>>> import numpy as np
|
|
1351
|
-
>>> from mindspore.ops.operations import _inner_ops
|
|
1352
|
-
>>> copy_slices = _inner_ops.TensorCopySlices()
|
|
1353
|
-
>>> out = copy_slices(Tensor(np.zeros((5, 5))), Tensor(np.ones((2, 5))), (3, 0), (5, 5), (1, 1))
|
|
1354
|
-
>>> print(out)
|
|
1355
|
-
[[1., 1., 1., 1., 1.],
|
|
1356
|
-
[1., 1., 1., 1., 1.],
|
|
1357
|
-
[1., 1., 1., 1., 1.],
|
|
1358
|
-
[0., 0., 0., 0., 0.],
|
|
1359
|
-
[0., 0., 0., 0., 0.]]
|
|
1360
|
-
|
|
1361
|
-
Supported Platforms:
|
|
1362
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
1363
|
-
"""
|
|
1364
|
-
|
|
1365
|
-
@prim_attr_register
|
|
1366
|
-
def __init__(self):
|
|
1367
|
-
"""Initialize TensorScatterUpdate"""
|
|
1368
|
-
self.init_prim_io_names(inputs=['x', 'value', 'begin', 'end', 'strides'], outputs=['y'])
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
1115
|
class DSDMatmul(PrimitiveWithInfer):
|
|
1372
1116
|
"""
|
|
1373
1117
|
The definition of the CusSquare primitive.
|
|
@@ -1592,46 +1336,6 @@ class DynamicBroadcastTo(Primitive):
|
|
|
1592
1336
|
self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y'])
|
|
1593
1337
|
|
|
1594
1338
|
|
|
1595
|
-
class Cummin(Primitive):
|
|
1596
|
-
r"""
|
|
1597
|
-
Returns the cumulative minimum of elements and the index.
|
|
1598
|
-
|
|
1599
|
-
.. warning::
|
|
1600
|
-
This is an experimental API that is subject to change or deletion.
|
|
1601
|
-
|
|
1602
|
-
Refer to :func:`mindspore.ops.cummin` for more detail.
|
|
1603
|
-
|
|
1604
|
-
Args:
|
|
1605
|
-
axis (int): The axis to accumulate the tensor's value. Must be in the range [-rank(input), rank(input)).
|
|
1606
|
-
|
|
1607
|
-
Inputs:
|
|
1608
|
-
- **input** (Tensor) - The input tensor.
|
|
1609
|
-
|
|
1610
|
-
Outputs:
|
|
1611
|
-
A tuple of 2 Tensors(values, indices), containing the cumulative minimum of elements and the index,
|
|
1612
|
-
The shape of each output tensor is the same as input `input`.
|
|
1613
|
-
|
|
1614
|
-
Supported Platforms:
|
|
1615
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
1616
|
-
|
|
1617
|
-
Examples:
|
|
1618
|
-
>>> from mindspore import Tensor, ops
|
|
1619
|
-
>>> import mindspore
|
|
1620
|
-
>>> a = Tensor([-0.2284, -0.6628, 0.0975, 0.2680, -1.3298, -0.4220], mindspore.float32)
|
|
1621
|
-
>>> func = ops.Cummin(axis=0)
|
|
1622
|
-
>>> output = func(a)
|
|
1623
|
-
>>> print(output[0])
|
|
1624
|
-
[-0.2284 -0.6628 -0.6628 -0.6628 -1.3298 -1.3298]
|
|
1625
|
-
>>> print(output[1])
|
|
1626
|
-
[0 1 1 1 4 4]
|
|
1627
|
-
"""
|
|
1628
|
-
|
|
1629
|
-
@prim_attr_register
|
|
1630
|
-
def __init__(self, axis):
|
|
1631
|
-
"""Initialize Cummin"""
|
|
1632
|
-
validator.check_value_type('axis', axis, [int], self.name)
|
|
1633
|
-
|
|
1634
|
-
|
|
1635
1339
|
class DynamicResizeNearestNeighbor(Primitive):
|
|
1636
1340
|
r"""
|
|
1637
1341
|
Resizes the input tensor by using the nearest neighbor algorithm.
|
|
@@ -1832,7 +1536,7 @@ class CellBackwardHook(PrimitiveWithInfer):
|
|
|
1832
1536
|
... print(grad)
|
|
1833
1537
|
...
|
|
1834
1538
|
>>> hook = inner.CellBackwardHook()
|
|
1835
|
-
>>> hook_fn_key = hook.register_backward_hook(
|
|
1539
|
+
>>> hook_fn_key = hook.register_backward_hook()
|
|
1836
1540
|
>>> def hook_test(x, y):
|
|
1837
1541
|
... z = x * y
|
|
1838
1542
|
... z = hook(z)
|
|
@@ -1853,16 +1557,19 @@ class CellBackwardHook(PrimitiveWithInfer):
|
|
|
1853
1557
|
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
|
|
1854
1558
|
"""
|
|
1855
1559
|
|
|
1856
|
-
def __init__(self, cell_id=""):
|
|
1560
|
+
def __init__(self, cell_id="", cell=None, hook_dict=None):
|
|
1857
1561
|
"""Initialize CellBackwardHook"""
|
|
1858
1562
|
super(CellBackwardHook, self).__init__(self.__class__.__name__)
|
|
1859
1563
|
self.cell_id = cell_id
|
|
1564
|
+
self.cell = cell
|
|
1565
|
+
self.hook_dict = weakref.ref(hook_dict)
|
|
1860
1566
|
self.add_prim_attr("cell_id", cell_id)
|
|
1861
|
-
self.
|
|
1567
|
+
self.grad_output = None
|
|
1862
1568
|
|
|
1863
|
-
def __call__(self, args):
|
|
1864
|
-
|
|
1865
|
-
|
|
1569
|
+
def __call__(self, *args):
|
|
1570
|
+
# If args is empty, just return.
|
|
1571
|
+
if not args:
|
|
1572
|
+
return args
|
|
1866
1573
|
return _run_op(self, self.name, args)
|
|
1867
1574
|
|
|
1868
1575
|
def infer_shape(self, *inputs_shape):
|
|
@@ -1875,51 +1582,76 @@ class CellBackwardHook(PrimitiveWithInfer):
|
|
|
1875
1582
|
return inputs_type[0]
|
|
1876
1583
|
return inputs_type
|
|
1877
1584
|
|
|
1878
|
-
def register_backward_hook(self
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
mode.
|
|
1882
|
-
|
|
1883
|
-
Note:
|
|
1884
|
-
The 'hook_fn' must be defined as the following code.
|
|
1885
|
-
`cell_id` is the information of registered cell. `grad_input` is the gradient passed to the cell.
|
|
1886
|
-
`grad_output` is the gradient computed and passed to the next cell or primitive, which may be modified by
|
|
1887
|
-
returning a new output gradient.
|
|
1888
|
-
The 'hook_fn' should have the following signature:
|
|
1889
|
-
hook_fn(cell_id, grad_input, grad_output) -> New output gradient or none.
|
|
1890
|
-
The 'hook_fn' is executed in the python environment.
|
|
1585
|
+
def register_backward_hook(self):
|
|
1586
|
+
"""
|
|
1587
|
+
Register the backward hook function.
|
|
1891
1588
|
|
|
1892
1589
|
Args:
|
|
1893
|
-
|
|
1590
|
+
None
|
|
1894
1591
|
|
|
1895
1592
|
Returns:
|
|
1896
|
-
|
|
1593
|
+
None
|
|
1897
1594
|
|
|
1898
|
-
|
|
1899
|
-
|
|
1595
|
+
Supported Platforms:
|
|
1596
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1900
1597
|
"""
|
|
1901
|
-
if not isinstance(hook_fn, (FunctionType, MethodType)):
|
|
1902
|
-
raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
|
|
1903
|
-
f"function, but got {type(hook_fn)}.")
|
|
1904
|
-
key = self.add_backward_hook_fn(hook_fn)
|
|
1905
|
-
return key
|
|
1906
1598
|
|
|
1907
|
-
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
|
|
1911
|
-
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
|
|
1599
|
+
def hook_backward_grad(grad):
|
|
1600
|
+
if self.grad_output is None:
|
|
1601
|
+
self.grad_output = grad
|
|
1602
|
+
# Indicates the first time of call backward hook, and need to wait for the second time call
|
|
1603
|
+
return self.cell_id
|
|
1604
|
+
backward_hook_grad_input = grad
|
|
1605
|
+
if self.hook_dict():
|
|
1606
|
+
backward_hooks = self.hook_dict().values()
|
|
1607
|
+
for hook in backward_hooks:
|
|
1608
|
+
res = hook(self.cell, backward_hook_grad_input, self.grad_output)
|
|
1609
|
+
if res is None:
|
|
1610
|
+
continue
|
|
1611
|
+
if not isinstance(res, tuple):
|
|
1612
|
+
res = (res,)
|
|
1613
|
+
if len(res) != len(grad):
|
|
1614
|
+
raise TypeError(
|
|
1615
|
+
"The backward hook return value size is {} not equal to expect grad input size {}".format(
|
|
1616
|
+
len(res), len(grad)))
|
|
1617
|
+
backward_hook_grad_input = res
|
|
1618
|
+
self.grad_output = None
|
|
1619
|
+
return backward_hook_grad_input
|
|
1620
|
+
|
|
1621
|
+
self.set_hook_fn(hook_backward_grad, HookType.BackwardHook)
|
|
1622
|
+
|
|
1623
|
+
def register_backward_pre_hook(self):
|
|
1624
|
+
"""
|
|
1625
|
+
Register the backward pre hook function.
|
|
1915
1626
|
|
|
1916
1627
|
Args:
|
|
1917
|
-
|
|
1628
|
+
None
|
|
1918
1629
|
|
|
1919
1630
|
Returns:
|
|
1920
|
-
None
|
|
1631
|
+
None
|
|
1632
|
+
|
|
1633
|
+
Supported Platforms:
|
|
1634
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
1921
1635
|
"""
|
|
1922
|
-
|
|
1636
|
+
|
|
1637
|
+
def hook_backward_pre_grad(grad):
|
|
1638
|
+
backward_pre_hook_grad = grad
|
|
1639
|
+
if self.hook_dict():
|
|
1640
|
+
backward_pre_hooks = self.hook_dict().values()
|
|
1641
|
+
for hook in backward_pre_hooks:
|
|
1642
|
+
res = hook(self.cell, backward_pre_hook_grad)
|
|
1643
|
+
if res is None:
|
|
1644
|
+
continue
|
|
1645
|
+
if not isinstance(res, tuple):
|
|
1646
|
+
res = (res,)
|
|
1647
|
+
if len(res) != len(grad):
|
|
1648
|
+
raise TypeError(
|
|
1649
|
+
"The backward pre hook return value size is {} not equal to expect output size {}".format(
|
|
1650
|
+
len(res), len(grad)))
|
|
1651
|
+
backward_pre_hook_grad = res
|
|
1652
|
+
return backward_pre_hook_grad
|
|
1653
|
+
|
|
1654
|
+
self.set_hook_fn(hook_backward_pre_grad, HookType.BackwardPreHook)
|
|
1923
1655
|
|
|
1924
1656
|
|
|
1925
1657
|
class Format(PrimitiveWithInfer):
|
|
@@ -1948,7 +1680,6 @@ class Format(PrimitiveWithInfer):
|
|
|
1948
1680
|
def __init__(self):
|
|
1949
1681
|
self.init_prim_io_names(inputs=['string', 'args'], outputs=['string'])
|
|
1950
1682
|
|
|
1951
|
-
|
|
1952
1683
|
def __infer__(self, str_, *var):
|
|
1953
1684
|
def check_variable(str_, var):
|
|
1954
1685
|
if _check_contains_variable(str_['dtype'], str_['value']):
|
|
@@ -1959,11 +1690,9 @@ class Format(PrimitiveWithInfer):
|
|
|
1959
1690
|
return True
|
|
1960
1691
|
return False
|
|
1961
1692
|
|
|
1962
|
-
|
|
1963
1693
|
if check_variable(str_, var):
|
|
1964
1694
|
return {'dtype': mstype.string, 'shape': [], 'value': None}
|
|
1965
1695
|
|
|
1966
|
-
|
|
1967
1696
|
str_value = str_['value']
|
|
1968
1697
|
kwargs = dict()
|
|
1969
1698
|
var_value = list()
|
|
@@ -2148,14 +1877,13 @@ class ClipByNorm(PrimitiveWithInfer):
|
|
|
2148
1877
|
@prim_attr_register
|
|
2149
1878
|
def __init__(self, axis=None):
|
|
2150
1879
|
"""Initialize ClipByNorm"""
|
|
2151
|
-
self.axis_str = 'axis'
|
|
2152
1880
|
self.axis = () if axis is None else axis
|
|
2153
|
-
validator.check_value_type(
|
|
1881
|
+
validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
|
|
2154
1882
|
axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
|
|
2155
1883
|
for i, value in enumerate(axis_check):
|
|
2156
1884
|
validator.check_value_type('axis[%d]' % i, value, [int], self.name)
|
|
2157
|
-
self.init_attrs[
|
|
2158
|
-
self.add_prim_attr(
|
|
1885
|
+
self.init_attrs['axis'] = self.axis
|
|
1886
|
+
self.add_prim_attr('axis', self.axis)
|
|
2159
1887
|
self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])
|
|
2160
1888
|
|
|
2161
1889
|
def infer_shape(self, x_shape, clip_norm_shape):
|
|
@@ -2248,7 +1976,8 @@ class MixedPrecisionCast(Primitive):
|
|
|
2248
1976
|
|
|
2249
1977
|
def __call__(self, dst_dtype, x):
|
|
2250
1978
|
def cast_inner(data):
|
|
2251
|
-
if isinstance(data, Tensor) and data.dtype in (mstype.float16, mstype.float32,
|
|
1979
|
+
if isinstance(data, Tensor) and data.dtype in (mstype.float16, mstype.float32,
|
|
1980
|
+
mstype.float64, mstype.bfloat16):
|
|
2252
1981
|
return self.cast(data, dst_dtype)
|
|
2253
1982
|
return data
|
|
2254
1983
|
|
|
@@ -2559,7 +2288,7 @@ class ConvertToMsTensor(Primitive):
|
|
|
2559
2288
|
"""Run in PyNative mode"""
|
|
2560
2289
|
if isinstance(x, StubTensor):
|
|
2561
2290
|
return StubTensor(stub=x.stub, tensor=x.tensor)
|
|
2562
|
-
return ops.deepcopy(x)
|
|
2291
|
+
return ops.auto_generate.deepcopy(x)
|
|
2563
2292
|
|
|
2564
2293
|
|
|
2565
2294
|
convert_to_ms_tensor = ConvertToMsTensor()
|
|
@@ -2621,28 +2350,6 @@ class IsParameter(PrimitiveWithInfer):
|
|
|
2621
2350
|
'value': isinstance(x['dtype'], mstype.RefType)}
|
|
2622
2351
|
|
|
2623
2352
|
|
|
2624
|
-
class SiLU(Primitive):
|
|
2625
|
-
r"""
|
|
2626
|
-
Computes SiLU (Sigmoid Linear Unit activation function) of input tensors element-wise.
|
|
2627
|
-
|
|
2628
|
-
Refer to :func:`mindspore.ops.silu` for more details.
|
|
2629
|
-
|
|
2630
|
-
Supported Platforms:
|
|
2631
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
2632
|
-
|
|
2633
|
-
Examples:
|
|
2634
|
-
>>> x = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)
|
|
2635
|
-
>>> output = ops.SiLU(x)
|
|
2636
|
-
>>> print(output)
|
|
2637
|
-
[-0.269 1.762 -0.1423 1.762 -0.269]
|
|
2638
|
-
"""
|
|
2639
|
-
|
|
2640
|
-
@prim_attr_register
|
|
2641
|
-
def __init__(self):
|
|
2642
|
-
"""Initialize SiLU"""
|
|
2643
|
-
self.init_prim_io_names(inputs=['x'], outputs=['output'])
|
|
2644
|
-
|
|
2645
|
-
|
|
2646
2353
|
class TileSize(Primitive):
|
|
2647
2354
|
r"""
|
|
2648
2355
|
Tile size for matmul
|
|
@@ -2726,6 +2433,7 @@ class CopyWithSlice(Primitive):
|
|
|
2726
2433
|
r"""
|
|
2727
2434
|
Copy data to discontinuous tensor
|
|
2728
2435
|
"""
|
|
2436
|
+
|
|
2729
2437
|
@prim_attr_register
|
|
2730
2438
|
def __init__(self):
|
|
2731
2439
|
self.add_prim_attr('side_effect_mem', True)
|
|
@@ -2775,10 +2483,10 @@ class FFN(Primitive):
|
|
|
2775
2483
|
>>> h = 1024
|
|
2776
2484
|
>>> h_f = 4 * h
|
|
2777
2485
|
>>> e = 16
|
|
2778
|
-
>>> x = Tensor(np.random.randn(
|
|
2486
|
+
>>> x = Tensor(np.random.randn(s, h).astype(np.float16))
|
|
2779
2487
|
>>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
|
|
2780
2488
|
>>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
|
|
2781
|
-
>>> expert_tokens = Tensor(np.
|
|
2489
|
+
>>> expert_tokens = Tensor(np.full(e, 8))
|
|
2782
2490
|
>>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
|
|
2783
2491
|
>>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
|
|
2784
2492
|
>>> ffn = _inner_ops.FFN("fastgelu", 1)
|
|
@@ -2790,189 +2498,47 @@ class FFN(Primitive):
|
|
|
2790
2498
|
def __init__(self, activation, inner_precise):
|
|
2791
2499
|
"""Initialize FFN."""
|
|
2792
2500
|
self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
|
|
2793
|
-
"bias2", "scale", "offset", "deq_scale1", "deq_scale2"
|
|
2501
|
+
"bias2", "scale", "offset", "deq_scale1", "deq_scale2",
|
|
2502
|
+
"antiquant_scale1", "antiquant_scale2",
|
|
2503
|
+
"antiquant_offset1", "antiquant_offset2"],
|
|
2794
2504
|
outputs=["y"])
|
|
2795
2505
|
cls_name = self.name
|
|
2796
2506
|
validator.check_value_type("activation", activation, [str], cls_name)
|
|
2797
2507
|
validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
|
|
2798
2508
|
|
|
2799
2509
|
|
|
2800
|
-
class
|
|
2801
|
-
r"""
|
|
2802
|
-
The DecoderKVCache is used for decoding the KVCache of transformer network.
|
|
2803
|
-
|
|
2804
|
-
Args:
|
|
2805
|
-
cache (Tensor): The cahe tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
|
|
2806
|
-
When seq_len_axis is 2, cache tensor of shape
|
|
2807
|
-
:math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
|
|
2808
|
-
When seq_len_axis is 1, cache tensor of shape
|
|
2809
|
-
:math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
|
|
2810
|
-
update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
|
|
2811
|
-
When seq_len_axis is 2, update tensor of shape
|
|
2812
|
-
:math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
|
|
2813
|
-
When seq_len_axis is 1, update tensor of shape
|
|
2814
|
-
:math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
|
|
2815
|
-
valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
|
|
2816
|
-
Valid_seq_len tensor of shape :math:`(batch\_size)`.
|
|
2817
|
-
batch_index (Tensor): The batch_index tensor with data type of int64.
|
|
2818
|
-
Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
|
|
2819
|
-
seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
|
|
2820
|
-
new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
|
|
2821
|
-
New_max_seq_len tensor of shape :math:`(1)`.
|
|
2822
|
-
Indicate that user want to change the shape of cache tensor from
|
|
2823
|
-
:math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
|
|
2824
|
-
:math:
|
|
2825
|
-
`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
|
|
2826
|
-
to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
|
|
2827
|
-
cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
|
|
2828
|
-
Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
|
|
2829
|
-
|
|
2830
|
-
Outputs:
|
|
2831
|
-
With same data type and same shape as `cache` tensor.
|
|
2832
|
-
|
|
2833
|
-
Supported Platforms:
|
|
2834
|
-
``Ascend``
|
|
2835
|
-
|
|
2836
|
-
Examples:
|
|
2837
|
-
>>> from mindspore.ops.operations import _inner_ops
|
|
2838
|
-
>>> b = 4
|
|
2839
|
-
>>> h = 40
|
|
2840
|
-
>>> max_s = 1024
|
|
2841
|
-
>>> s = 1
|
|
2842
|
-
>>> d = 128
|
|
2843
|
-
>>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
|
|
2844
|
-
>>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
|
|
2845
|
-
>>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
|
|
2846
|
-
>>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
|
|
2847
|
-
>>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
|
|
2848
|
-
>>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
|
|
2849
|
-
>>> decoder_kv_cache = _inner_ops.DecoderKVCache()
|
|
2850
|
-
>>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
|
|
2851
|
-
>>> print(cache)
|
|
2510
|
+
class _VirtualConverterEnd(PrimitiveWithInfer):
|
|
2852
2511
|
"""
|
|
2853
|
-
|
|
2854
|
-
def __init__(self):
|
|
2855
|
-
"""Initialize DecoderKVCache."""
|
|
2856
|
-
self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
|
|
2857
|
-
"new_max_seq_len", "cur_max_seq_len"],
|
|
2858
|
-
outputs=["out"])
|
|
2859
|
-
self.add_prim_attr('side_effect_mem', True)
|
|
2860
|
-
|
|
2861
|
-
|
|
2862
|
-
class _MirrorSilentCheck(PrimitiveWithInfer):
|
|
2512
|
+
Auto parallel virtual operator.
|
|
2863
2513
|
"""
|
|
2864
|
-
The operator _MirrorSilentCheck implements accuracy-sensitive detection on the tensor input in backpropagator.
|
|
2865
|
-
Call _MirrorSilentCheck in method __call__ of derived class to implement accuracy-sensitive detection.
|
|
2866
|
-
|
|
2867
|
-
Inputs:
|
|
2868
|
-
- **input** (Tensor) : The tensor used for detection.
|
|
2869
|
-
Its data type must be mindspore.float16, mindspore.float32 or mindspore.bfloat16.
|
|
2870
|
-
- **pre_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
|
|
2871
|
-
Please only generated by method generate_params() of ASDBase.
|
|
2872
|
-
- **min_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
|
|
2873
|
-
Please only generated by method generate_params() of ASDBase.
|
|
2874
|
-
- **max_val** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
|
|
2875
|
-
Please only generated by method generate_params() of ASDBase.
|
|
2876
|
-
- **cnt** (Parameter(Tensor)) : Support parameter in accuracy-sensitive detection.
|
|
2877
|
-
Please only generated by method generate_params() of ASDBase.
|
|
2878
|
-
After each invocation of _MirrorSilentCheck, increment the value of cnt by one.
|
|
2879
2514
|
|
|
2880
|
-
Outputs:
|
|
2881
|
-
- **output** (Tensor) - Same shape, type and value as `input`.
|
|
2882
|
-
"""
|
|
2883
2515
|
@prim_attr_register
|
|
2884
|
-
def __init__(self,
|
|
2885
|
-
|
|
2886
|
-
self.
|
|
2887
|
-
self.thresh_l1 = upper_thresh[0]
|
|
2888
|
-
self.coeff_l1 = sigma_thresh[0]
|
|
2889
|
-
self.thresh_l2 = upper_thresh[1]
|
|
2890
|
-
self.coeff_l2 = sigma_thresh[1]
|
|
2891
|
-
self.add_prim_attr('side_effect_mem', True)
|
|
2892
|
-
|
|
2893
|
-
def parse_thresh(self, env_var_name, default_value, min_value):
|
|
2894
|
-
env_var = os.environ.get(env_var_name, default=default_value)
|
|
2895
|
-
thresh = [value.strip() for value in env_var.split(",")]
|
|
2896
|
-
if len(thresh) != 2 or not all(value.isdigit() for value in thresh):
|
|
2897
|
-
thresh = default_value.split(",")
|
|
2898
|
-
thresh = [float(max(int(value), min_value)) for value in thresh]
|
|
2899
|
-
if thresh[0] <= thresh[1]:
|
|
2900
|
-
thresh = [float(value) for value in default_value.split(",")]
|
|
2901
|
-
|
|
2902
|
-
return thresh
|
|
2903
|
-
|
|
2904
|
-
def get_thresh(self):
|
|
2905
|
-
upper_thresh = self.parse_thresh("NPU_ASD_UPPER_THRESH", "1000000,10000", 3)
|
|
2906
|
-
sigma_thresh = self.parse_thresh("NPU_ASD_SIGMA_THRESH", "100000,5000", 3)
|
|
2907
|
-
return upper_thresh, sigma_thresh
|
|
2516
|
+
def __init__(self, input_nums):
    """Initialize _VirtualConverterEnd.

    Args:
        input_nums: Value stored on the instance as ``self.input_nums``;
            ``infer_shape`` multiplies the leading dimension of the first
            input shape by it.
    """
    self.input_nums = input_nums
|
|
2908
2519
|
|
|
2909
|
-
def infer_shape(self,
|
|
2910
|
-
return
|
|
2911
|
-
|
|
2912
|
-
def infer_dtype(self, x_dtype, pre_dtype, min_dtype, max_dtype, n_dtype, loss_scale_dtype):
|
|
2913
|
-
return x_dtype
|
|
2520
|
+
def infer_shape(self, *args):
    """Infer the output shape from the first input shape.

    Only ``args[0]`` is consulted: its leading dimension is multiplied by
    ``self.input_nums`` and the remaining dimensions are kept unchanged.

    Returns:
        tuple: the scaled leading dimension followed by the trailing
        dimensions of the first input shape.
    """
    first_shape = args[0]
    leading_dim = first_shape[0] * self.input_nums
    trailing_dims = tuple(first_shape[1:])
    return (leading_dim,) + trailing_dims
|
|
2914
2522
|
|
|
2523
|
+
def infer_dtype(self, *args):
    """Infer the output dtype: the dtype of the first input is passed through.

    Any further positional dtypes are ignored.
    """
    first_dtype = args[0]
    return first_dtype
|
|
2915
2525
|
|
|
2916
|
-
class PromptKVCache(Primitive):
|
|
2917
|
-
r"""
|
|
2918
|
-
The PromptKVCache is used for prefill the KVCache of transformer network.
|
|
2919
2526
|
|
|
2920
|
-
|
|
2921
|
-
|
|
2922
|
-
|
|
2923
|
-
|
|
2924
|
-
When seq_len_axis is 1, cache tensor of shape
|
|
2925
|
-
:math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
|
|
2926
|
-
update (Tensor]): The tensor which is used to update the cache tensor. Same data type as cache tensor.
|
|
2927
|
-
When seq_len_axis is 2, update tensor of shape
|
|
2928
|
-
:math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
|
|
2929
|
-
When seq_len_axis is 1, update tensor of shape
|
|
2930
|
-
:math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
|
|
2931
|
-
valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
|
|
2932
|
-
Valid_seq_len tensor of shape :math:`(batch\_size)`.
|
|
2933
|
-
batch_index (Tensor): The batch_index tensor with data type of int64.
|
|
2934
|
-
Batch_index tensor of shape :math:`(1)`. Indicate that which batch of cache tensor is going to be update.
|
|
2935
|
-
seq_len_axis (int64): The seq_len_axis indicate which axis is seq_eln, set to '1' or '2'. Default: "2".
|
|
2936
|
-
new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
|
|
2937
|
-
New_max_seq_len tensor of shape :math:`(1)`.
|
|
2938
|
-
Indicate that user want to change the shape of cache tensor from
|
|
2939
|
-
:math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
|
|
2940
|
-
:math:
|
|
2941
|
-
`(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
|
|
2942
|
-
to update the cache tensor. This will not real change the shape of `cache` tensor. Not able for now.
|
|
2943
|
-
cur_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
|
|
2944
|
-
Cur_max_seq_len tensor of shape :math:`(1)`. Keep the current seq_len of cache tensor. Not abel for now.
|
|
2945
|
-
align_mode (int64): indicate which axis is seq_eln, 0 is 'right', 1 is 'left'. Default: 0.
|
|
2527
|
+
class _VirtualConverterBegin(PrimitiveWithInfer):
|
|
2528
|
+
"""
|
|
2529
|
+
Auto parallel virtual operator.
|
|
2530
|
+
"""
|
|
2946
2531
|
|
|
2947
|
-
|
|
2948
|
-
|
|
2532
|
+
@prim_attr_register
def __init__(self, output_nums):
    """Initialize _VirtualConverterBegin.

    Args:
        output_nums: Value stored on the instance as ``self.output_nums``;
            the infer methods use it as the number of outputs to produce.
    """
    self.output_nums = output_nums
|
|
2949
2536
|
|
|
2950
|
-
|
|
2951
|
-
|
|
2537
|
+
def infer_shape(self, arg):
    """Infer the shapes of the ``self.output_nums`` outputs.

    The leading dimension of the input shape is divided by
    ``self.output_nums``; the remaining dimensions are kept unchanged, and
    the resulting shape is repeated once per output.

    Args:
        arg: Input shape (an indexable sequence of dimensions).

    Returns:
        tuple: ``self.output_nums`` copies of the per-output shape tuple.

    Raises:
        ValueError: If ``self.output_nums`` is zero.
    """
    if self.output_nums == 0:
        # Raise the error instead of returning the exception object, which
        # would otherwise be handed back to the caller as if it were a shape.
        raise ValueError("output_nums can\'t be zero.")
    # Floor division keeps the dimension an integer; true division would
    # turn the leading dimension into a float.
    new_arg = (arg[0] // self.output_nums,) + tuple(arg[1:])
    return (new_arg,) * self.output_nums
|
|
2952
2542
|
|
|
2953
|
-
|
|
2954
|
-
|
|
2955
|
-
>>> from mindspore.ops.operations import _inner_ops
|
|
2956
|
-
>>> b = 4
|
|
2957
|
-
>>> h = 40
|
|
2958
|
-
>>> max_s = 1024
|
|
2959
|
-
>>> s = 256
|
|
2960
|
-
>>> d = 128
|
|
2961
|
-
>>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
|
|
2962
|
-
>>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
|
|
2963
|
-
>>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
|
|
2964
|
-
>>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
|
|
2965
|
-
>>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
|
|
2966
|
-
>>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
|
|
2967
|
-
>>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
|
|
2968
|
-
>>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
|
|
2969
|
-
>>> print(cache)
|
|
2970
|
-
"""
|
|
2971
|
-
@prim_attr_register
|
|
2972
|
-
def __init__(self, padding_mode="right"):
|
|
2973
|
-
"""Initialize PromptKVCache."""
|
|
2974
|
-
self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
|
|
2975
|
-
"new_max_seq_len", "cur_max_seq_len"],
|
|
2976
|
-
outputs=["out"])
|
|
2977
|
-
self.add_prim_attr('side_effect_mem', True)
|
|
2978
|
-
self.padding_mode = padding_mode
|
|
2543
|
+
def infer_dtype(self, arg):
    """Infer the output dtypes: the input dtype repeated once per output.

    Returns:
        tuple: ``self.output_nums`` copies of ``arg``.
    """
    repeated = (arg,)
    return repeated * self.output_nums
|