mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.10__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +46 -19
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +38 -0
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +488 -528
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
|
@@ -15,7 +15,7 @@
|
|
|
15
15
|
"""The impl of flash attention"""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
import mindspore.ops as ops
|
|
18
|
-
|
|
18
|
+
import mindspore.common.dtype as mstype
|
|
19
19
|
from mindspore.ops import Custom
|
|
20
20
|
from mindspore.ops import DataType
|
|
21
21
|
from mindspore.ops import TBERegOp
|
|
@@ -39,31 +39,28 @@ cus_flash_atten_op_info = TBERegOp("FlashAttentionPrimitive") \
|
|
|
39
39
|
.input(0, "query", False, "required", "all") \
|
|
40
40
|
.input(1, "key", False, "required", "all") \
|
|
41
41
|
.input(2, "value", False, "required", "all") \
|
|
42
|
-
.input(3, "
|
|
43
|
-
.input(4, "
|
|
44
|
-
.input(5, "
|
|
45
|
-
.input(6, "alibi_mask", False, "optional", "all") \
|
|
42
|
+
.input(3, "attn_mask", False, "optional", "all") \
|
|
43
|
+
.input(4, "dropout_mask", False, "optional", "all") \
|
|
44
|
+
.input(5, "alibi_mask", False, "optional", "all") \
|
|
46
45
|
.output(0, "output", False, "required", "all") \
|
|
47
46
|
.output(1, "rowsum", False, "required", "all") \
|
|
48
47
|
.output(2, "rowmax", False, "required", "all") \
|
|
49
|
-
.dtype_format(DataType.
|
|
50
|
-
DataType.
|
|
51
|
-
DataType.
|
|
52
|
-
DataType.
|
|
53
|
-
DataType.F16_Default,
|
|
54
|
-
DataType.F16_Default,
|
|
55
|
-
DataType.F16_Default,
|
|
48
|
+
.dtype_format(DataType.F16_FracNZ,
|
|
49
|
+
DataType.F16_FracNZ,
|
|
50
|
+
DataType.F16_FracNZ,
|
|
51
|
+
DataType.F16_FracNZ,
|
|
56
52
|
DataType.F16_Default,
|
|
53
|
+
DataType.F16_FracNZ,
|
|
54
|
+
DataType.F16_FracNZ,
|
|
57
55
|
DataType.F16_Default,
|
|
58
56
|
DataType.F16_Default) \
|
|
59
|
-
.dtype_format(DataType.
|
|
60
|
-
DataType.
|
|
61
|
-
DataType.
|
|
62
|
-
DataType.
|
|
63
|
-
DataType.F16_Default,
|
|
64
|
-
DataType.F16_Default,
|
|
65
|
-
DataType.F16_Default,
|
|
57
|
+
.dtype_format(DataType.F16_FracNZ,
|
|
58
|
+
DataType.F16_FracNZ,
|
|
59
|
+
DataType.F16_FracNZ,
|
|
60
|
+
DataType.F16_FracNZ,
|
|
66
61
|
DataType.F16_Default,
|
|
62
|
+
DataType.F16_FracNZ,
|
|
63
|
+
DataType.F16_FracNZ,
|
|
67
64
|
DataType.F32_Default,
|
|
68
65
|
DataType.F16_Default) \
|
|
69
66
|
.get_op_info()
|
|
@@ -88,41 +85,38 @@ cus_flash_atten_grad_op_info = TBERegOp("FlashAttentionGradPrimitive") \
|
|
|
88
85
|
.input(4, "do", False, "required", "all") \
|
|
89
86
|
.input(5, "rowsum", False, "required", "all") \
|
|
90
87
|
.input(6, "rowmax", False, "required", "all") \
|
|
91
|
-
.input(7, "
|
|
92
|
-
.input(8, "
|
|
93
|
-
.input(9, "
|
|
94
|
-
.input(10, "alibi_mask", False, "optional", "all") \
|
|
88
|
+
.input(7, "attn_mask", False, "optional", "all") \
|
|
89
|
+
.input(8, "dropout_mask", False, "optional", "all") \
|
|
90
|
+
.input(9, "alibi_mask", False, "optional", "all") \
|
|
95
91
|
.output(0, "dq", False, "required", "all") \
|
|
96
92
|
.output(1, "dk", False, "required", "all") \
|
|
97
93
|
.output(2, "dv", False, "required", "all") \
|
|
98
|
-
.dtype_format(DataType.
|
|
99
|
-
DataType.
|
|
100
|
-
DataType.
|
|
101
|
-
DataType.
|
|
102
|
-
DataType.
|
|
103
|
-
DataType.F16_Default,
|
|
104
|
-
DataType.F16_Default,
|
|
105
|
-
DataType.
|
|
106
|
-
DataType.F16_Default,
|
|
107
|
-
DataType.
|
|
108
|
-
DataType.
|
|
109
|
-
DataType.
|
|
110
|
-
DataType.
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
DataType.
|
|
114
|
-
DataType.
|
|
115
|
-
DataType.
|
|
116
|
-
DataType.F16_Default,
|
|
94
|
+
.dtype_format(DataType.F16_FracNZ,
|
|
95
|
+
DataType.F16_FracNZ,
|
|
96
|
+
DataType.F16_FracNZ,
|
|
97
|
+
DataType.F16_FracNZ,
|
|
98
|
+
DataType.F16_FracNZ,
|
|
99
|
+
DataType.F16_Default,
|
|
100
|
+
DataType.F16_Default,
|
|
101
|
+
DataType.F16_FracNZ,
|
|
102
|
+
DataType.F16_Default,
|
|
103
|
+
DataType.F16_FracNZ,
|
|
104
|
+
DataType.F32_FracNZ,
|
|
105
|
+
DataType.F32_FracNZ,
|
|
106
|
+
DataType.F32_FracNZ) \
|
|
107
|
+
.dtype_format(DataType.F16_FracNZ,
|
|
108
|
+
DataType.F16_FracNZ,
|
|
109
|
+
DataType.F16_FracNZ,
|
|
110
|
+
DataType.F16_FracNZ,
|
|
111
|
+
DataType.F16_FracNZ,
|
|
117
112
|
DataType.F32_Default,
|
|
118
113
|
DataType.F16_Default,
|
|
119
|
-
DataType.
|
|
120
|
-
DataType.F16_Default,
|
|
114
|
+
DataType.F16_FracNZ,
|
|
121
115
|
DataType.F16_Default,
|
|
122
|
-
DataType.
|
|
123
|
-
DataType.
|
|
124
|
-
DataType.
|
|
125
|
-
DataType.
|
|
116
|
+
DataType.F16_FracNZ,
|
|
117
|
+
DataType.F32_FracNZ,
|
|
118
|
+
DataType.F32_FracNZ,
|
|
119
|
+
DataType.F32_FracNZ) \
|
|
126
120
|
.get_op_info()
|
|
127
121
|
|
|
128
122
|
|
|
@@ -131,11 +125,11 @@ def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
|
|
|
131
125
|
"""get flash attention grad"""
|
|
132
126
|
|
|
133
127
|
def infer_shape(q_shape, k_shape, v_shape, o_shape, do_shape, l_shape, m_shape,
|
|
134
|
-
|
|
128
|
+
att_mask_shape, dropout_mask_shape, alibi_mask_shape):
|
|
135
129
|
return q_shape, k_shape, v_shape
|
|
136
130
|
|
|
137
131
|
def infer_dtype(q_dtype, k_dtype, v_dtype, o_dytpe, do_dtype, l_dtype, m_dtype,
|
|
138
|
-
|
|
132
|
+
attn_mask_dtype, dropout_mask_dtype, alibi_mask_type):
|
|
139
133
|
return mstype.float32, mstype.float32, mstype.float32
|
|
140
134
|
|
|
141
135
|
fa_grad = Custom(flash_attention_grad, out_shape=infer_shape,
|
|
@@ -145,20 +139,20 @@ def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
|
|
|
145
139
|
fa_grad.add_prim_attr("high_precision", high_precision)
|
|
146
140
|
fa_grad.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
|
|
147
141
|
fa_grad.init_prim_io_names(
|
|
148
|
-
inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "
|
|
142
|
+
inputs=["query", "key", "value", "output", "do", "rowsum", "rowmax", "attn_mask", "dropout_mask",
|
|
149
143
|
"alibi_mask"],
|
|
150
144
|
outputs=["dq", "dk", "dv"]
|
|
151
145
|
)
|
|
152
146
|
|
|
153
|
-
def bprop(query, key, value,
|
|
147
|
+
def bprop(query, key, value, attn_mask, dropout_mask, alibi_mask, out, douts):
|
|
154
148
|
output, rowsum, rowmax = out
|
|
155
149
|
dout, _, _ = douts
|
|
156
|
-
dq, dk, dv = fa_grad(query, key, value, output, dout, rowsum, rowmax,
|
|
150
|
+
dq, dk, dv = fa_grad(query, key, value, output, dout, rowsum, rowmax, attn_mask, dropout_mask,
|
|
157
151
|
alibi_mask)
|
|
158
152
|
dq = ops.cast(dq, mstype.float16)
|
|
159
153
|
dk = ops.cast(dk, mstype.float16)
|
|
160
154
|
dv = ops.cast(dv, mstype.float16)
|
|
161
|
-
return dq, dk, dv, zeros_like(
|
|
155
|
+
return dq, dk, dv, zeros_like(attn_mask), \
|
|
162
156
|
zeros_like(dropout_mask), zeros_like(alibi_mask)
|
|
163
157
|
|
|
164
158
|
return bprop
|
|
@@ -167,7 +161,7 @@ def get_flash_attention_grad(prev_block_num=65536, next_block_num=65536,
|
|
|
167
161
|
def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_name='sparse', high_precision=False):
|
|
168
162
|
"""get_flash_attention"""
|
|
169
163
|
|
|
170
|
-
def infer_shape(q_shape, k_shape, v_shape,
|
|
164
|
+
def infer_shape(q_shape, k_shape, v_shape, attn_mask_shape=None,
|
|
171
165
|
dropout_mask_shape=None, alibi_mask_shape=None):
|
|
172
166
|
"""infer shape"""
|
|
173
167
|
batch, hidden_size, seq_len, _ = q_shape
|
|
@@ -175,7 +169,7 @@ def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_
|
|
|
175
169
|
m_shape = (batch, hidden_size, seq_len)
|
|
176
170
|
return q_shape, l_shape, m_shape
|
|
177
171
|
|
|
178
|
-
def infer_dtype(q_dtype, k_dtype, v_dtype,
|
|
172
|
+
def infer_dtype(q_dtype, k_dtype, v_dtype, attn_mask_dtype=None,
|
|
179
173
|
dropout_mask_dtype=None, alibi_mask_type=None):
|
|
180
174
|
"""infer type"""
|
|
181
175
|
l_dtype = mstype.float16
|
|
@@ -192,7 +186,7 @@ def get_flash_attention(prev_block_num=65536, next_block_num=65536, tiling_stgy_
|
|
|
192
186
|
fa_forward.add_prim_attr("high_precision", high_precision)
|
|
193
187
|
fa_forward.add_prim_attr("tiling_stgy_name", tiling_stgy_name)
|
|
194
188
|
fa_forward.init_prim_io_names(
|
|
195
|
-
inputs=["query", "key", "value", "
|
|
189
|
+
inputs=["query", "key", "value", "attn_mask", "dropout_mask", "alibi_mask"],
|
|
196
190
|
outputs=["output", "rowsum", "rowmax"]
|
|
197
191
|
)
|
|
198
192
|
|
|
@@ -19,7 +19,6 @@ from mindspore.ops._op_impl._custom_op.flash_attention.constants import DTYPE_SI
|
|
|
19
19
|
from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP16
|
|
20
20
|
from mindspore.ops._op_impl._custom_op.flash_attention.constants import FP32
|
|
21
21
|
from mindspore.ops._op_impl._custom_op.flash_attention.constants import L0C
|
|
22
|
-
from mindspore.ops._op_impl._custom_op.flash_attention.constants import L1
|
|
23
22
|
from mindspore.ops._op_impl._custom_op.flash_attention.constants import UB
|
|
24
23
|
|
|
25
24
|
|
|
@@ -179,7 +178,7 @@ class TikOpsUtils:
|
|
|
179
178
|
def broadcast(self, vec_ub, shape):
|
|
180
179
|
""" broadcast a vector to a matrix
|
|
181
180
|
:param vec_ub: a tensor in UB with shape of (M,), and dtype is float16
|
|
182
|
-
:param shape: the target shape, a tuple with value (M, N)
|
|
181
|
+
:param shape: the target shape, a tuple with value (M, N), M and N are integer multiples of 16
|
|
183
182
|
:return: a tensor in UB with shape of (M, N)
|
|
184
183
|
"""
|
|
185
184
|
M, N = shape
|
|
@@ -321,27 +320,16 @@ class TikOpsUtils:
|
|
|
321
320
|
)
|
|
322
321
|
return vec_rec_ub
|
|
323
322
|
|
|
324
|
-
def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, rowsum_ub, m, k, precision_type):
|
|
323
|
+
def row_sum_cube_impl(self, matrix_l1_K1MK0_ed, right_all_one_matrix_l1, rowsum_ub, m, k, precision_type):
|
|
325
324
|
"""用cube实现矩阵行和:右乘一个shape=(n,1)全一矩阵
|
|
326
325
|
:param matrix_l1_K1MK0_ed: input tensor with shape (K1, M, K0)
|
|
327
|
-
:param
|
|
326
|
+
:param right_all_one_matrix_l1: input tensor with shape (K, 16)
|
|
327
|
+
:param rowsum_ub: output tensor stores the row sum of input tensor
|
|
328
328
|
:param m: actual tensor height
|
|
329
329
|
:param k: actual tensor width
|
|
330
330
|
:return: row sum of the output tensor
|
|
331
331
|
"""
|
|
332
332
|
K1, M, K0 = matrix_l1_K1MK0_ed.shape
|
|
333
|
-
K = K1 * K0
|
|
334
|
-
|
|
335
|
-
# 构造全一右矩阵,由于cube无法处理shape=(n, 1),所以shape=(n, 16),全一矩阵不需分形
|
|
336
|
-
right_all_one_matrix_ub = self.tik_instance.Tensor(
|
|
337
|
-
FP16, (K, 16), name="right_all_one_matrix_ub", scope=UB
|
|
338
|
-
)
|
|
339
|
-
self.tik_instance.h_duplicate(right_all_one_matrix_ub, 1.0)
|
|
340
|
-
right_all_one_matrix_l1 = self.tik_instance.Tensor(
|
|
341
|
-
FP16, (K1 * K0, 16), name="right_all_one_matrix_l1", scope=L1
|
|
342
|
-
)
|
|
343
|
-
self.cont_data_mv_1_bust(dst=right_all_one_matrix_l1, src=right_all_one_matrix_ub, burst=K)
|
|
344
|
-
|
|
345
333
|
# 调用matmul实现rowsum,结果shape=(m, 16),取每行的第一个数
|
|
346
334
|
with self.tik_instance.new_stmt_scope(disable_sync=False):
|
|
347
335
|
row_sum_ub_N1MN0 = self.matmul_compute(matrix_l1_K1MK0_ed, right_all_one_matrix_l1, m, k, 16,
|
|
@@ -352,6 +340,7 @@ class TikOpsUtils:
|
|
|
352
340
|
cur_row_sum = self.tik_instance.Scalar(FP32, init_value=row_sum_ub_MN_ed[idx, 0])
|
|
353
341
|
rowsum_ub[idx].set_as(cur_row_sum)
|
|
354
342
|
else:
|
|
343
|
+
# row_sum_ub_MN_ed 先转置,然后取一行, 替换原来按行操作: lij_ub[i].set_as(row_sum_ub_MN_ed[i, 0])
|
|
355
344
|
row_sum_ub_trans = self.tik_instance.Tensor(FP16, (16, M), name="row_sum_ub_trans", scope=UB)
|
|
356
345
|
row_sum_ub_trans = self.transpose_matrix(row_sum_ub_MN_ed, row_sum_ub_trans, M, True)
|
|
357
346
|
self.cont_data_mv_1_bust(dst=rowsum_ub, src=row_sum_ub_trans, burst=M // 16)
|
|
@@ -409,7 +398,7 @@ class TikOpsUtils:
|
|
|
409
398
|
offset = vec_len - a_burst_num
|
|
410
399
|
last_blk_ub = self.tik_instance.Tensor(FP16, (a_burst_num,), name="last_blk_ub", scope=UB)
|
|
411
400
|
self.cont_data_mv_1_bust(dst=last_blk_ub, src=src_tensor[gm_offset + offset], burst=1)
|
|
412
|
-
with self.tik_instance.for_range(0, a_burst_num) as idx: # offset非32bytes
|
|
401
|
+
with self.tik_instance.for_range(0, a_burst_num) as idx: # offset非32bytes对齐, 无法用datamove
|
|
413
402
|
dst_tensor[offset + idx].set_as(last_blk_ub[idx])
|
|
414
403
|
|
|
415
404
|
def move_vector_from_ub_to_gm(self, dst_tensor, src_tensor, gm_offset, block_h):
|
|
@@ -29,7 +29,7 @@ class WukongTiling(TilingStrategy):
|
|
|
29
29
|
反向的空间分布待详细分析
|
|
30
30
|
N = (4096, 1024, 256, 64) 或 77
|
|
31
31
|
Nq = (4096, 1024, 256, 64)
|
|
32
|
-
d = dv = (40, 80, 160
|
|
32
|
+
d = dv = (40, 80, 160, 160)
|
|
33
33
|
"""
|
|
34
34
|
if self.N <= 77: # [77, 64]
|
|
35
35
|
# cross-attention or self-attention of (64, 64, 160)
|
|
@@ -108,6 +108,7 @@ from .search_sorted import _search_sorted_aicpu
|
|
|
108
108
|
from .stack import _stack_aicpu
|
|
109
109
|
from .unstack import _unstack_aicpu
|
|
110
110
|
from .unsorted_segment_sum import _unsorted_segment_sum_aicpu
|
|
111
|
+
from .unsorted_segment_prod import _unsorted_segment_prod_aicpu
|
|
111
112
|
from .addcmul import _addcmul_aicpu
|
|
112
113
|
from .uniform_candidate_sampler import _uniform_candidate_sampler_aicpu
|
|
113
114
|
from .log_uniform_candidate_sampler import _log_uniform_candidate_sampler_aicpu
|
|
@@ -145,6 +146,7 @@ from .upsample_trilinear_3d import _upsample_trilinear_3d_aicpu
|
|
|
145
146
|
from .upsample_trilinear_3d_grad import _upsample_trilinear_3d_grad_aicpu
|
|
146
147
|
from .upper_bound import _upper_bound_aicpu
|
|
147
148
|
from .cache_swap_table import _cache_swap_table_aicpu
|
|
149
|
+
from .uniform import _uniform_aicpu
|
|
148
150
|
from .uniform_int import _uniform_int_aicpu
|
|
149
151
|
from .uniform_real import _uniform_real_aicpu
|
|
150
152
|
from .standard_laplace import _standard_laplace_aicpu
|
|
@@ -156,12 +158,13 @@ from .fused_sparse_adam import _fused_sparse_adam_aicpu
|
|
|
156
158
|
from .fused_sparse_lazy_adam import _fused_sparse_lazy_adam_aicpu
|
|
157
159
|
from .fused_sparse_ftrl import _fused_sparse_ftrl_aicpu
|
|
158
160
|
from .sparse_fill_empty_rows_grad import _sparse_fill_empty_rows_grad_aicpu
|
|
161
|
+
from .sparse_reorder import _sparse_reorder_aicpu
|
|
159
162
|
from .sparse_reshape import _sparse_reshape_aicpu
|
|
160
163
|
from .sparse_segment_sqrt_n_grad import _sparse_segment_sqrt_n_grad_aicpu
|
|
161
164
|
from .sparse_segment_sum import _sparse_segment_sum_aicpu
|
|
162
165
|
from .sparse_segment_sum_with_num_segments import _sparse_segment_sum_with_num_segments_aicpu
|
|
163
166
|
from .sparse_softmax_cross_entropy_with_logits_v2 import _sparse_softmax_cross_entropy_with_logits_v2_aicpu
|
|
164
|
-
from .
|
|
167
|
+
from .sparse_sparse_maximum import _sparse_sparse_maximum_aicpu
|
|
165
168
|
from .split import _split_aicpu
|
|
166
169
|
from .transpose import _transpose_aicpu
|
|
167
170
|
from .tril_indices import _tril_indices_aicpu
|
|
@@ -205,6 +208,7 @@ from .environ_get import _environ_get_aicpu
|
|
|
205
208
|
from .environ_destroy_all import _environ_destroy_all_aicpu
|
|
206
209
|
from .cross import _cross_aicpu
|
|
207
210
|
from .check_numerics import _check_numerics_aicpu
|
|
211
|
+
from .cummax import _cummax_aicpu
|
|
208
212
|
from .cumsum import _cumsum_aicpu
|
|
209
213
|
from .round import _round_aicpu
|
|
210
214
|
from .stft import _stft_aicpu
|
|
@@ -229,6 +233,7 @@ from .scatter_nd_update import _scatter_nd_update_aicpu
|
|
|
229
233
|
from .scatter_nd_max import _scatter_nd_max_aicpu
|
|
230
234
|
from .conj import _conj_aicpu
|
|
231
235
|
from .scatter_nd_min import _scatter_nd_min_aicpu
|
|
236
|
+
from .scatter_add_with_axis import _scatter_add_with_axis_aicpu
|
|
232
237
|
from .compare_and_bitpack import _compare_and_bitpack_aicpu
|
|
233
238
|
from .addcdiv import _addcdiv_aicpu
|
|
234
239
|
from .unique_consecutive import _unique_consecutive_aicpu
|
|
@@ -241,8 +246,8 @@ from .reservoir_replay_buffer import _rrb_push_op_cpu
|
|
|
241
246
|
from .reservoir_replay_buffer import _rrb_sample_op_cpu
|
|
242
247
|
from .reservoir_replay_buffer import _rrb_destroy_op_cpu
|
|
243
248
|
from .concat_offset import _concat_offset_aicpu
|
|
244
|
-
from .concat_offset_v1 import _concat_offset_v1_aicpu
|
|
245
249
|
from .range import _range_aicpu
|
|
250
|
+
from .range_v2 import _range_v2_aicpu
|
|
246
251
|
from .slice_grad import _slice_grad_aicpu
|
|
247
252
|
from .median import _median_aicpu
|
|
248
253
|
from .median_grad import _median_grad_aicpu
|
|
@@ -272,6 +277,7 @@ from .complex import _complex_aicpu
|
|
|
272
277
|
from .complex_abs import _complex_abs_aicpu
|
|
273
278
|
from .concat import _concat_aicpu
|
|
274
279
|
from .cos import _cos_aicpu
|
|
280
|
+
from .count_nonzero import _count_nonzero_aicpu
|
|
275
281
|
from .csr_sparse_matrix_to_dense import _csr_sparse_matrix_to_dense_aicpu
|
|
276
282
|
from .cumprod import _cumprod_aicpu
|
|
277
283
|
from .exp import _exp_aicpu
|
|
@@ -340,6 +346,7 @@ from .hypot import _hypot_aicpu
|
|
|
340
346
|
from .identity_n import _identity_n_aicpu
|
|
341
347
|
from .index_fill import _index_fill_aicpu
|
|
342
348
|
from .index_put import _index_put_aicpu
|
|
349
|
+
from .inplace_index_add import _inplace_index_add_aicpu
|
|
343
350
|
from .kldivloss import _kldiv_loss_aicpu
|
|
344
351
|
from .kldivlossgrad import _kldiv_loss_grad_aicpu
|
|
345
352
|
from .lcm import _lcm_aicpu
|
|
@@ -400,6 +407,9 @@ from .non_deterministic_ints import _non_deterministic_ints_aicpu
|
|
|
400
407
|
from .pow import _pow_aicpu
|
|
401
408
|
from .real import _real_aicpu
|
|
402
409
|
from .resize_area import _resize_area_aicpu
|
|
410
|
+
from .segment_mean import _segment_mean_aicpu
|
|
411
|
+
from .segment_min import _segment_min_aicpu
|
|
412
|
+
from .segment_prod import _segment_prod_aicpu
|
|
403
413
|
from .segment_sum import _segment_sum_aicpu
|
|
404
414
|
from .set_size import _set_size_aicpu
|
|
405
415
|
from .slice import _slice_aicpu
|
|
@@ -411,6 +421,7 @@ from .sparse_tensor_dense_mat_mul import _sparse_tensor_dense_mat_mul_aicpu
|
|
|
411
421
|
from .trace import _trace_aicpu
|
|
412
422
|
from .tracegrad import _tracegrad_aicpu
|
|
413
423
|
from .tridiagonal_solve import _tridiagonal_solve_aicpu
|
|
424
|
+
from .tridiagonal_matmul import _tridiagonal_matmul_aicpu
|
|
414
425
|
from .truncated_normal import _truncated_normal_aicpu
|
|
415
426
|
from .glu import _glu_aicpu
|
|
416
427
|
from .deformable_offsets import _deformable_offsets_aicpu
|
|
@@ -426,3 +437,4 @@ from .sequence_concat import _sequence_concat_aicpu
|
|
|
426
437
|
from .sequence_stack import _sequence_stack_aicpu
|
|
427
438
|
from .affine_grid import _affine_grid_aicpu
|
|
428
439
|
from .depth_to_space import _depth_to_space_aicpu
|
|
440
|
+
from .eps import _eps_aicpu
|
|
@@ -29,9 +29,9 @@ add_op_info = AiCPURegOp("Add") \
|
|
|
29
29
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
|
30
30
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
|
31
31
|
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
|
|
32
|
-
.dtype_format(DataType.U16_Default, DataType.
|
|
33
|
-
.dtype_format(DataType.U32_Default, DataType.
|
|
34
|
-
.dtype_format(DataType.U64_Default, DataType.
|
|
32
|
+
.dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
|
|
33
|
+
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
|
|
34
|
+
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
|
|
35
35
|
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
|
|
36
36
|
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
|
|
37
37
|
.get_op_info()
|
|
@@ -31,7 +31,6 @@ bias_add_grad_op_info = AiCPURegOp("BiasAddGrad") \
|
|
|
31
31
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
|
32
32
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
|
33
33
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
|
34
|
-
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
|
35
34
|
.dtype_format(DataType.C64_Default, DataType.C64_Default) \
|
|
36
35
|
.dtype_format(DataType.C128_Default, DataType.C128_Default) \
|
|
37
36
|
.get_op_info()
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""CountNonZero op"""
|
|
17
|
+
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
|
18
|
+
|
|
19
|
+
count_nonzero_op_info = AiCPURegOp("CountNonZero") \
|
|
20
|
+
.fusion_type("OPAQUE") \
|
|
21
|
+
.input(0, "x", "required") \
|
|
22
|
+
.output(0, "y", "required") \
|
|
23
|
+
.attr("dims", "listInt")\
|
|
24
|
+
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
|
|
25
|
+
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
|
|
26
|
+
.dtype_format(DataType.I32_Default, DataType.I64_Default) \
|
|
27
|
+
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
|
28
|
+
.dtype_format(DataType.U8_Default, DataType.I64_Default) \
|
|
29
|
+
.dtype_format(DataType.U16_Default, DataType.I64_Default) \
|
|
30
|
+
.dtype_format(DataType.U32_Default, DataType.I64_Default) \
|
|
31
|
+
.dtype_format(DataType.U64_Default, DataType.I64_Default) \
|
|
32
|
+
.dtype_format(DataType.F16_Default, DataType.I64_Default) \
|
|
33
|
+
.dtype_format(DataType.F32_Default, DataType.I64_Default) \
|
|
34
|
+
.dtype_format(DataType.F64_Default, DataType.I64_Default) \
|
|
35
|
+
.dtype_format(DataType.C64_Default, DataType.I64_Default) \
|
|
36
|
+
.dtype_format(DataType.C128_Default, DataType.I64_Default) \
|
|
37
|
+
.get_op_info()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@op_info_register(count_nonzero_op_info)
|
|
41
|
+
def _count_nonzero_aicpu():
|
|
42
|
+
"""CountNonZero AiCPU register"""
|
|
43
|
+
return
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
|
|
16
|
+
"""Eps op"""
|
|
17
|
+
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
|
18
|
+
|
|
19
|
+
eps_op_info = AiCPURegOp("Eps") \
|
|
20
|
+
.fusion_type("OPAQUE") \
|
|
21
|
+
.input(0, "x", "required") \
|
|
22
|
+
.output(0, "y", "required") \
|
|
23
|
+
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
|
24
|
+
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
|
25
|
+
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
|
26
|
+
.get_op_info()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@op_info_register(eps_op_info)
|
|
30
|
+
def _eps_aicpu():
|
|
31
|
+
"""Eps AiCPU register"""
|
|
32
|
+
return
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
|
|
16
|
-
"""
|
|
16
|
+
"""Gamma op"""
|
|
17
17
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
|
18
18
|
|
|
19
19
|
gamma_op_info = AiCPURegOp("Gamma") \
|
|
@@ -32,5 +32,5 @@ gamma_op_info = AiCPURegOp("Gamma") \
|
|
|
32
32
|
|
|
33
33
|
@op_info_register(gamma_op_info)
|
|
34
34
|
def _gamma_aicpu():
|
|
35
|
-
"""
|
|
35
|
+
"""Gamma AiCPU register"""
|
|
36
36
|
return
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -18,15 +18,18 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
|
|
|
18
18
|
log_uniform_candidate_sampler_op_info = AiCPURegOp("LogUniformCandidateSampler") \
|
|
19
19
|
.fusion_type("OPAQUE") \
|
|
20
20
|
.input(0, "true_classes", "required") \
|
|
21
|
+
.input(1, "counts", "required") \
|
|
22
|
+
.input(2, "states", "required") \
|
|
21
23
|
.output(0, "sampled_candidates", "required") \
|
|
22
24
|
.output(1, "true_expected_count", "required") \
|
|
23
|
-
.output(2, "
|
|
25
|
+
.output(2, "sampled_expected_count", "required") \
|
|
24
26
|
.attr("num_true", "int") \
|
|
25
27
|
.attr("num_sampled", "int") \
|
|
26
28
|
.attr("unique", "bool") \
|
|
27
29
|
.attr("range_max", "int") \
|
|
28
30
|
.attr("seed", "int") \
|
|
29
|
-
.dtype_format(DataType.I64_Default, DataType.
|
|
31
|
+
.dtype_format(DataType.I64_Default, DataType.U64_Default, DataType.U64_Default, DataType.I64_Default,
|
|
32
|
+
DataType.F32_Default, DataType.F32_Default) \
|
|
30
33
|
.get_op_info()
|
|
31
34
|
|
|
32
35
|
|
|
@@ -19,7 +19,6 @@ from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataTyp
|
|
|
19
19
|
lu_unpack_grad_op_info = AiCPURegOp("LuUnpackGrad") \
|
|
20
20
|
.fusion_type("OPAQUE") \
|
|
21
21
|
.attr("L_grad_flag", "bool") \
|
|
22
|
-
.attr("L_grad_flag", "bool") \
|
|
23
22
|
.input(0, "L_grad", "required") \
|
|
24
23
|
.input(1, "U_grad", "required") \
|
|
25
24
|
.input(2, "LU_data", "required") \
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2022-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -20,8 +20,8 @@ multinomial_op_info = AiCPURegOp("Multinomial") \
|
|
|
20
20
|
.fusion_type("OPAQUE") \
|
|
21
21
|
.input(0, "input", "required") \
|
|
22
22
|
.input(1, "num_sample", "required") \
|
|
23
|
-
.input(2, "
|
|
24
|
-
.input(3, "
|
|
23
|
+
.input(2, "counts", "required") \
|
|
24
|
+
.input(3, "states", "required") \
|
|
25
25
|
.output(0, "output", "required") \
|
|
26
26
|
.attr("dtype", "Type") \
|
|
27
27
|
.attr("seed", "int") \
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2022-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -22,21 +22,29 @@ parameterized_truncated_normal_op_info = AiCPURegOp("ParameterizedTruncatedNorma
|
|
|
22
22
|
.input(2, "stdevs", "required") \
|
|
23
23
|
.input(3, "min", "required") \
|
|
24
24
|
.input(4, "max", "required") \
|
|
25
|
+
.input(5, "counts", "required") \
|
|
26
|
+
.input(6, "states", "required") \
|
|
25
27
|
.output(0, "y", "required") \
|
|
26
28
|
.attr("seed", "int")\
|
|
27
29
|
.attr("seed2", "int")\
|
|
28
30
|
.dtype_format(DataType.I32_Default, DataType.F16_Default, DataType.F16_Default,
|
|
29
|
-
DataType.F16_Default, DataType.F16_Default, DataType.
|
|
31
|
+
DataType.F16_Default, DataType.F16_Default, DataType.U64_Default,
|
|
32
|
+
DataType.U64_Default, DataType.F16_Default) \
|
|
30
33
|
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default,
|
|
31
|
-
DataType.F32_Default, DataType.F32_Default, DataType.
|
|
34
|
+
DataType.F32_Default, DataType.F32_Default, DataType.U64_Default,
|
|
35
|
+
DataType.U64_Default, DataType.F32_Default) \
|
|
32
36
|
.dtype_format(DataType.I32_Default, DataType.F64_Default, DataType.F64_Default,
|
|
33
|
-
DataType.F64_Default, DataType.F64_Default, DataType.
|
|
37
|
+
DataType.F64_Default, DataType.F64_Default, DataType.U64_Default,
|
|
38
|
+
DataType.U64_Default, DataType.F64_Default) \
|
|
34
39
|
.dtype_format(DataType.I64_Default, DataType.F16_Default, DataType.F16_Default,
|
|
35
|
-
DataType.F16_Default, DataType.F16_Default, DataType.
|
|
40
|
+
DataType.F16_Default, DataType.F16_Default, DataType.U64_Default,
|
|
41
|
+
DataType.U64_Default, DataType.F16_Default) \
|
|
36
42
|
.dtype_format(DataType.I64_Default, DataType.F32_Default, DataType.F32_Default,
|
|
37
|
-
DataType.F32_Default, DataType.F32_Default, DataType.
|
|
43
|
+
DataType.F32_Default, DataType.F32_Default, DataType.U64_Default,
|
|
44
|
+
DataType.U64_Default, DataType.F32_Default) \
|
|
38
45
|
.dtype_format(DataType.I64_Default, DataType.F64_Default, DataType.F64_Default,
|
|
39
|
-
DataType.F64_Default, DataType.F64_Default, DataType.
|
|
46
|
+
DataType.F64_Default, DataType.F64_Default, DataType.U64_Default,
|
|
47
|
+
DataType.U64_Default, DataType.F64_Default) \
|
|
40
48
|
.get_op_info()
|
|
41
49
|
|
|
42
50
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -21,25 +21,45 @@ random_categorical_op_info = AiCPURegOp("RandomCategorical") \
|
|
|
21
21
|
.input(0, "logits", "required") \
|
|
22
22
|
.input(1, "num_sample", "required") \
|
|
23
23
|
.input(2, "seed", "required") \
|
|
24
|
+
.input(3, "counts", "required") \
|
|
25
|
+
.input(4, "states", "required") \
|
|
24
26
|
.output(0, "output", "required") \
|
|
25
|
-
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.
|
|
26
|
-
|
|
27
|
-
.dtype_format(DataType.
|
|
28
|
-
|
|
29
|
-
.dtype_format(DataType.
|
|
30
|
-
|
|
31
|
-
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.
|
|
32
|
-
|
|
33
|
-
.dtype_format(DataType.
|
|
34
|
-
|
|
35
|
-
.dtype_format(DataType.
|
|
36
|
-
|
|
37
|
-
.dtype_format(DataType.F16_Default, DataType.
|
|
38
|
-
|
|
39
|
-
.dtype_format(DataType.
|
|
40
|
-
|
|
41
|
-
.dtype_format(DataType.
|
|
42
|
-
|
|
27
|
+
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
28
|
+
DataType.U64_Default, DataType.I16_Default) \
|
|
29
|
+
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
30
|
+
DataType.U64_Default, DataType.I16_Default) \
|
|
31
|
+
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
32
|
+
DataType.U64_Default, DataType.I16_Default) \
|
|
33
|
+
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
34
|
+
DataType.U64_Default, DataType.I32_Default) \
|
|
35
|
+
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
36
|
+
DataType.U64_Default, DataType.I32_Default) \
|
|
37
|
+
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
38
|
+
DataType.U64_Default, DataType.I32_Default) \
|
|
39
|
+
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
40
|
+
DataType.U64_Default, DataType.I64_Default) \
|
|
41
|
+
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
42
|
+
DataType.U64_Default, DataType.I64_Default) \
|
|
43
|
+
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default,
|
|
44
|
+
DataType.U64_Default, DataType.I64_Default) \
|
|
45
|
+
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
46
|
+
DataType.U64_Default, DataType.I16_Default) \
|
|
47
|
+
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
48
|
+
DataType.U64_Default, DataType.I16_Default) \
|
|
49
|
+
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
50
|
+
DataType.U64_Default, DataType.I16_Default) \
|
|
51
|
+
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
52
|
+
DataType.U64_Default, DataType.I32_Default) \
|
|
53
|
+
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
54
|
+
DataType.U64_Default, DataType.I32_Default) \
|
|
55
|
+
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
56
|
+
DataType.U64_Default, DataType.I32_Default) \
|
|
57
|
+
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
58
|
+
DataType.U64_Default, DataType.I64_Default) \
|
|
59
|
+
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
60
|
+
DataType.U64_Default, DataType.I64_Default) \
|
|
61
|
+
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.U64_Default,
|
|
62
|
+
DataType.U64_Default, DataType.I64_Default) \
|
|
43
63
|
.get_op_info()
|
|
44
64
|
|
|
45
65
|
@op_info_register(random_categorical_op_info)
|