mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
@@ -30,15 +30,16 @@ from mindspore.ops.primitive import Primitive
 from mindspore.ops.primitive import PrimitiveWithInfer
 from mindspore.ops.primitive import PrimitiveWithCheck
 from mindspore.ops.primitive import prim_attr_register
-from
+from mindspore.run_check._check_version import AscendEnvChecker
+from ..auto_generate import (CeLU, Flatten, LogSoftmax, LogSoftmaxExt, ReLU, ReLU6, Dense, Tanh,
                              Elu, Sigmoid, Softmax, SoftplusExt, HSwish, HSigmoid, AvgPool, BiasAdd,
-                             NLLLoss, OneHot, GeLU, FastGeLU, PReLU, RmsNorm,
+                             NLLLoss, OneHot, GeLU, FastGeLU, PReLU, RmsNorm, IncreFlashAttention, MSELossExt,
                              GridSampler3D, GridSampler2D, LayerNorm, LayerNormExt, HShrink, AdamWeightDecay, Dropout,
                              ApplyRotaryPosEmb, PagedAttention, PagedAttentionMask, ReshapeAndCache,
                              FlashAttentionScore, Embedding, UpsampleNearest1D, UpsampleNearest2D,
                              UpsampleNearest3D, UpsampleTrilinear3D,
                              UpsampleBilinear2D, UpsampleLinear1D,
-                             BinaryCrossEntropy, BCEWithLogitsLoss)
+                             BinaryCrossEntropy, BCEWithLogitsLoss, SoftShrink)
 from .manually_defined import BatchNorm


@@ -453,7 +454,7 @@ class ReLUV3(Primitive):
     Inputs:
         - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of
           additional dimensions, data type is
-          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore
+          `number <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.dtype.html>`_.

     Outputs:
         Tensor of shape :math:`(N, *)`, with the same type and shape as the `input_x`.
@@ -569,8 +570,6 @@ class SeLU(Primitive):
         self.init_prim_io_names(inputs=['input_x'], outputs=['output'])


-
-
 class FusedBatchNorm(Primitive):
     r"""
     The FusedBatchNorm interface is deprecated, please use the BatchNorm interface.
@@ -3075,9 +3074,9 @@ class LSTM(Primitive):
     Args:
         input_size (int): Number of features of input.
         hidden_size (int): Number of features of hidden layer.
-        num_layers (int): Number of layers of stacked LSTM.
-        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh
-        bidirectional (bool): Specifies whether it is a bidirectional LSTM.
+        num_layers (int): Number of layers of stacked LSTM, , which is only support `1` on CPU.
+        has_bias (bool): Whether the cell has bias `b_ih` and `b_hh` , which is only support `False` on CPU.
+        bidirectional (bool): Specifies whether it is a bidirectional LSTM, , which is only support `False` on CPU.
         dropout (float): If not 0, append `Dropout` layer on the outputs of each
             LSTM layer except the last layer. The range of dropout is [0.0, 1.0].
         proj_size (int): If `proj_size` > 0, a projection of the corresponding size will be used,
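For reference, a minimal construction sketch that stays within the CPU limits documented above (the feature sizes here are illustrative, not from the diff):

    import mindspore
    from mindspore import ops

    # CPU-safe per the 2.4.0 docstring: num_layers=1, has_bias=False,
    # bidirectional=False.
    lstm = ops.LSTM(input_size=10, hidden_size=16, num_layers=1,
                    has_bias=False, bidirectional=False, dropout=0.0)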
@@ -3776,6 +3775,7 @@ class AdamNoUpdateParam(Primitive):
     @prim_attr_register
     def __init__(self, use_locking=False, use_nesterov=False):
         """Initialize AdamNoUpdateParam."""
+        self.add_prim_attr('side_effect_mem', True)
         validator.check_value_type("use_locking", use_locking, [bool], self.name)
         validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name)

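The added `side_effect_mem` attribute declares that the primitive mutates memory, which keeps graph-mode optimization passes from pruning or reordering it. A minimal sketch of the same pattern on a hypothetical custom primitive (not part of the diff):

    from mindspore.ops import Primitive, prim_attr_register

    class MyInplaceOp(Primitive):  # hypothetical example
        @prim_attr_register
        def __init__(self):
            """Initialize MyInplaceOp."""
            # Declare a memory side effect so the executor preserves this op.
            self.add_prim_attr('side_effect_mem', True)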
@@ -6376,6 +6376,9 @@ class AvgPool3D(Primitive):
         \frac{1}{d_{ker} * h_{ker} * w_{ker}} \sum_{l=0}^{d_{ker}-1} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
         \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n)

+    Note:
+        This interface currently does not support Atlas A2 training series products.
+
     Args:
         kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value,
             is an int number that represents depth, height and width are both kernel_size, or a tuple
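As an illustration of the documented behavior on the still-supported backends (shapes are made up; per the new note, the op is unavailable on Atlas A2):

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    pool = ops.AvgPool3D(kernel_size=2, strides=2)
    x = Tensor(np.random.randn(1, 2, 4, 4, 4), mindspore.float32)
    out = pool(x)  # averages over 2x2x2 windows -> shape (1, 2, 2, 2, 2)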
@@ -7091,7 +7094,7 @@ class CTCLossV2Grad(Primitive):
         zero_infinity (bool): Whether to set infinite loss and correlation gradient to zero. Default: ``False`` .

     Inputs:
-        - **grad_out** (
+        - **grad_out** (Tensor) - Gradient renewal codfficient, A tensor for shape (N), where N is batch size.
         - **log_probs** (Tensor) - A tensor of shape (T, N, C), where T is input length, N is batch size and C is number
           of classes (including blank).
         - **targets** (Tensor) - A tensor of shape (N, S), where S is max target length, means the target sequences.
@@ -7461,43 +7464,6 @@ class Dilation2D(Primitive):
         self.add_prim_attr('dilation', self.dilation)


-class SoftShrink(Primitive):
-    r"""
-    Applies the SoftShrink function element-wise.
-
-    Refer to :func:`mindspore.ops.softshrink` for more details.
-
-    Args:
-        lambd(float, optional): The :math:`\lambda` must be no less than zero. Default: ``0.5`` .
-
-    Inputs:
-        - **input_x** (Tensor) - The input of soft shrink with data type of float16 or float32.
-
-    Outputs:
-        Tensor, has the same shape and data type as `input_x`.
-
-    Supported Platforms:
-        ``Ascend`` ``GPU`` ``CPU``
-
-    Examples:
-        >>> import mindspore
-        >>> import numpy as np
-        >>> from mindspore import Tensor, ops
-        >>> input_x = Tensor(np.array([[ 0.5297, 0.7871, 1.1754], [ 0.7836, 0.6218, -1.1542]]), mindspore.float16)
-        >>> softshrink = ops.SoftShrink()
-        >>> output = softshrink(input_x)
-        >>> print(output)
-        [[ 0.02979 0.287 0.676 ]
-         [ 0.2837 0.1216 -0.6543 ]]
-    """
-
-    @prim_attr_register
-    def __init__(self, lambd=0.5):
-        """Initialize SoftShrink"""
-        validator.check_value_type("lambd", lambd, [float], self.name)
-        validator.check_number("lambd", lambd, 0, validator.GE, self.name)
-
-
 class ApplyAdagradDA(Primitive):
     r"""
     Update `var` according to the proximal adagrad scheme.
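This deletion pairs with the import hunk at the top of the file: `SoftShrink` is now pulled in from `..auto_generate` rather than hand-written here, so user-facing code should be unaffected. A usage sketch built from the removed docstring example:

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    input_x = Tensor(np.array([[0.5297, 0.7871, 1.1754],
                               [0.7836, 0.6218, -1.1542]]), mindspore.float16)
    output = ops.SoftShrink()(input_x)  # shrinks values toward zero by lambd=0.5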
@@ -9591,79 +9557,19 @@ class PromptFlashAttention(Primitive):
                                 outputs=["attention_out"])


-class IncreFlashAttention(Primitive):
-    r"""
-    The interface for fully inference.
-
-    B -- Batch size
-
-    S -- Sequence length
-
-    H -- Hidden size
-
-    .. warning::
-        This is an experimental API that is subject to change or deletion.
-        If there is no input parameter and no default value, None needs to be passed.
-
-    Args:
-        - **num_heads** (int) - The number of heads.
-        - **input_layout** (str) - the data layout of the input qkv, support `(BSH)` and `(BNSD)`. Default `BSH`.
-        - **scale_value** (double) - The scale value indicating the scale coefficient, which is used as the scalar of
-          Muls in the calculation. Default: 1.0.
-        - **num_key_value_heads** (int) - head numbers of key/value which are used in GQA algorithm.
-          The value o indicates if the key and value have the same head nums, use numHeads. Default: 0.
-        - **block_size** (int) - Default: 0.
-        - **inner_precise** (int) - Default: 1.
-
-    Inputs:
-        - **query** (Tensor) - The query tensor with data type of float16 or bfloat16.
-          Input tensor of shape :math:`(B, 1, H)` / :math:`(B, N, 1, D)`.
-        - **key** (TensorList) - The key tensor with data type of float16 or bfloat16.
-          Input tensor of shape :math:`(B, S, H)` / :math:`(B, N, S, D)`.
-        - **value** (TensorList) - The value tensor with data type of float16 or bfloat16.
-          Input tensor of shape :math:`(B, S, H)` / :math:`(B, N, S, D)`.
-        - **attn_mask** (Tensor) - The attention mask tensor with data type of float16 or bool.
-          Input tensor of shape :math:`(B, S)` / :math:`(B, 1, S)` / :math:`(B, 1, 1, S)`.
-        - **actual_seq_lengths** (Tensor) - Describe actual sequence length of each input with data type of int.
-        - **pse_shift** (Tensor) - The position encoding tensor with data type of float16 or float32.
-        - **dequant_scale1** (Tensor) - Quantitative parametor, the tensor with data type of uint64.
-        - **quant_scale1** (Tensor) - Quantitative parametor, the tensor with data type of float.
-        - **dequant_scale2** (Tensor) - Quantitative parametor, the tensor with data type of uint64.
-        - **quant_scale2** (Tensor) - Quantitative parametor, the tensor with data type of float.
-        - **quant_offset2** (Tensor) - Quantitative parametor, the tensor with data type of float.
-        - **antiquant_scale** (Tensor) - Quantitative parametor, the tensor with data type of float.
-        - **antiquant_offset** (Tensor) - Quantitative parametor, the tensor with data type of float.
-        - **block_table** (Tensor) - The tensor with data type of float.
-
-    Outputs:
-        - **attention_out** (Tensor) - Input tensor of shape :math:`(B, 1, H)` / :math:`(B, N, 1, D)`.
-
-    Supported Platforms:
-        ``Ascend``
-    """
-
-    @prim_attr_register
-    def __init__(self, num_heads, input_layout="BSH", scale_value=1.0, num_key_value_heads=0, block_size=0,
-                 inner_precise=1):
-        """Initialize IncreFlashAttention."""
-        validator.check_value_type('num_heads', num_heads, [int], self.name)
-        validator.check_value_type('input_layout', input_layout, [str], self.name)
-        validator.check_value_type('scale_value', scale_value, [float], self.name)
-        validator.check_value_type('num_key_value_heads', num_key_value_heads, [int], self.name)
-        validator.check_value_type('block_size', block_size, [int], self.name)
-        validator.check_value_type('inner_precise', inner_precise, [int], self.name)
-        self.init_prim_io_names(inputs=["query", "key", "value", "attn_mask", "actual_seq_lengths", "pse_shift",
-                                        "dequant_scale1", "quant_scale1", "dequant_scale2", "quant_scale2",
-                                        "quant_offset2", "antiquant_scale", "antiquant_offset", "block_table"],
-                                outputs=["attention_out"])
-
-
 class AllFinite(Primitive):
     r"""
     Check all gradients is finite.
     """
+
     @prim_attr_register
     def __init__(self):
         """Initialize"""
         self.init_prim_io_names(inputs=['gradients'],
                                 outputs=["is_finite"])
+        if context.get_context("device_target") == "Ascend":
+            checker = AscendEnvChecker(None)
+            if not checker.check_custom_version():
+                raise RuntimeError(
+                    "The version of Ascend AI software package installed "
+                    "in the current environment does not support AllFinite.")
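Like `SoftShrink`, `IncreFlashAttention` is now imported from `..auto_generate` (see the import hunk above). The new `AllFinite.__init__` check also means construction itself can fail on Ascend when the installed package is too old; a hedged sketch of guarding against that (the import path assumes these hunks come from `mindspore/ops/operations/nn_ops.py`):

    from mindspore.ops.operations.nn_ops import AllFinite  # assumed location

    try:
        all_finite = AllFinite()
    except RuntimeError as err:
        # Raised on Ascend when the Ascend AI package lacks AllFinite support.
        print(err)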
@@ -300,10 +300,10 @@ class SampleDistortedBoundingBoxV2(Primitive):

     @prim_attr_register
     def __init__(self, seed=0, seed2=0, \
-
-
-
-
+                 aspect_ratio_range=(0.75, 1.33), \
+                 area_range=(0.05, 1.0), \
+                 max_attempts=100, \
+                 use_image_if_no_bounding_boxes=False):
         validator.check_is_int(seed, "seed", self.name)
         validator.check_is_int(seed2, "seed2", self.name)
         validator.check_value_type("aspect_ratio_range", aspect_ratio_range, [list, tuple], self.name)
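Construction is unchanged by this reformatting of the signature; a hedged sketch using the defaults now spelled out above:

    from mindspore import ops

    # Defaults as listed in the new signature.
    op = ops.SampleDistortedBoundingBoxV2(seed=1, aspect_ratio_range=(0.75, 1.33),
                                          area_range=(0.05, 1.0), max_attempts=100)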
@@ -584,6 +584,9 @@ class StopGradient(Primitive):
     pass


+stop_gradient_ = StopGradient()
+
+
 class ConfusionMatrix(PrimitiveWithInfer):
     r"""
     Calculates the confusion matrix from labels and predictions.
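The module now keeps a shared `stop_gradient_` instance, presumably so the functional form can reuse it instead of re-creating the primitive on every call. For reference, the functional usage:

    from mindspore import Tensor, ops

    x = Tensor([1.0, 2.0, 3.0])
    y = ops.stop_gradient(x)  # y is excluded from gradient computation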
@@ -89,6 +89,10 @@ class TruncatedNormal(Primitive):
     - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
       to worry about which seed is more important.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -153,6 +157,10 @@ class StandardNormal(Primitive):
     - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
       to worry about which seed is more important.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -204,6 +212,10 @@ class StandardLaplace(Primitive):
     - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
       to worry about which seed is more important.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -367,6 +379,10 @@ class Gamma(PrimitiveWithInfer):
     - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
       to worry about which seed is more important.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -450,6 +466,10 @@ class ParameterizedTruncatedNormal(Primitive):
     - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
       to worry about which seed is more important.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -672,6 +692,10 @@ class UniformInt(Primitive):
     - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
       to worry about which seed is more important.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -737,6 +761,10 @@ class UniformReal(Primitive):
     - GPU: int32, int64.
     - CPU: int16, int32, int64.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -837,6 +865,10 @@ class RandomCategorical(PrimitiveWithInfer):
     r"""
     Generates random samples from a given categorical distribution tensor.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` parameter has no effect.
+
     Args:
         dtype (mindspore.dtype): The type of output. Its value must be one of mstype.int16,
             mstype.int32 and mstype.int64. Default: ``mstype.int64`` .
@@ -903,6 +935,10 @@ class Multinomial(Primitive):
     - Using the Philox algorithm to scramble seed and seed2 to obtain random seed so that the user doesn't need
       to worry about which seed is more important.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` and `seed2` parameter have no effect.
+
     Args:
         seed (int, optional): The operator-level random seed, used to generate random numbers,
             must be non-negative. Default: ``0`` .
@@ -1012,6 +1048,11 @@ class UniformCandidateSampler(Primitive):

     Refer to :func:`mindspore.ops.uniform_candidate_sampler` for more details.

+    .. warning::
+        - The Ascend backend does not support the reproducibility of random numbers, so
+          the `seed` parameter has no effect.
+        - The Ascend backend does not support dynamic shape scenarios currently.
+
     Args:
         num_true (int): The number of target classes in each training example.
         num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
@@ -1026,7 +1067,7 @@ class UniformCandidateSampler(Primitive):

     Inputs:
         - **true_classes** (Tensor) - A Tensor. The target classes with a Tensor shape of
-          :math:`(batch\_size, num\_true)`.
+          :math:`(batch\_size, num\_true)`. The value range of the elements must be :math:`[0, range\_max)`.

     Outputs:
         - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
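A hedged sketch respecting the clarified constraint that every element of `true_classes` must lie in `[0, range_max)` (the numbers are illustrative):

    import numpy as np
    import mindspore
    from mindspore import Tensor, ops

    # num_true=1, num_sampled=3, unique=False, range_max=8
    sampler = ops.UniformCandidateSampler(1, 3, False, 8)
    true_classes = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64))  # all in [0, 8)
    sampled, true_expected, sampled_expected = sampler(true_classes)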
@@ -1086,6 +1127,10 @@ class LogUniformCandidateSampler(Primitive):

     Refer to :func:`mindspore.ops.log_uniform_candidate_sampler` for more details.

+    .. warning::
+        The Ascend backend does not support the reproducibility of random numbers, so
+        the `seed` parameter has no effect.
+
     Args:
         num_true (int, optional): The number of target classes per training example. Default: ``1`` .
         num_sampled (int, optional): The number of classes to randomly sample. Default: ``5`` .
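These hunks all add the same caveat: operator-level seeds are honored on GPU/CPU but ignored on Ascend. A short illustration with one of the affected ops:

    from mindspore import ops

    stdnormal = ops.StandardNormal(seed=2, seed2=4)
    out = stdnormal((3, 4))  # reproducible on GPU/CPU; on Ascend the seeds have no effect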
mindspore/ops/primitive.py CHANGED
@@ -171,6 +171,9 @@ class Primitive(Primitive_):
             if not isinstance(in_value, int) and self.name not in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY:
                 raise TypeError(f'The {log_info}: {strategy} of {self.name} is not valid,'
                                 f' the value of strategy must be int type, but got:{type(in_value)}')
+            if isinstance(in_value, Layout) and (self.name in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY):
+                is_layout.append(True)
+                continue
             is_layout.append(False)
             continue
         is_layout.append(True)
@@ -188,9 +191,17 @@ class Primitive(Primitive_):
             raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
         layout_value = ()
         for in_ele in layout:
-            if
-
-
+            if isinstance(in_ele, Layout):
+                layout_value += (in_ele.to_dict(),)
+            elif isinstance(in_ele, tuple):
+                new_layout_list = ()
+                for ele in in_ele:
+                    if not isinstance(ele, Layout):
+                        raise TypeError(f"The {log_info} item should be a object of class Layout.")
+                    new_layout_list += (ele.to_dict(),)
+                layout_value += (new_layout_list,)
+            else:
+                raise TypeError(f"The {log_info} item should be a object of class Layout or a tuple.")
         return layout_value

     def _check_shard_strategy_in_out_match(self, in_strategy, out_strategy):
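The new branch serializes `Layout` objects, and tuples of them, via `to_dict()`, which is what lets ops in `SUPPORTED_TUPLE_IN_TUPLE_STRATEGY` accept nested layouts. A hedged sketch of the strategy shape this code handles (the device matrix values are illustrative):

    from mindspore import Layout, ops

    # A 2x2 device matrix with axes named "dp" and "mp".
    layout = Layout((2, 2), ("dp", "mp"))
    matmul = ops.MatMul()
    # Each input gets a Layout; for ops in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY,
    # a tuple of Layouts per input is now also accepted.
    matmul.shard(in_strategy=(layout("dp", "mp"), layout("mp", "None")))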
@@ -299,7 +310,8 @@ class Primitive(Primitive_):
         if out_strategy is not None:
             out_is_layout = self._check_shard_strategy(out_strategy, "out_strategy")
             self._check_shard_strategy_in_out_match(in_strategy, out_strategy)
-            if in_is_layout is not None and out_is_layout is not None and
+            if in_is_layout is not None and out_is_layout is not None and (
+                    (in_is_layout[0] != out_is_layout[0]) and (self.name not in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY)):
                 raise ValueError(f'The in_strategy type must equal to the out_strategy type, '
                                  f'one using tuple(tuple) and the other using tuple(Layout) is not allowed.')
             in_layout_value = None
@@ -549,7 +561,7 @@ class PrimitiveWithCheck(Primitive):
     the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.

     More on how to customize a Op, please refer to `Custom Operators
-    <https://www.mindspore.cn/
+    <https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_.

     Args:
         name (str): Name of the current Primitive.
@@ -643,7 +655,7 @@ class PrimitiveWithInfer(Primitive):
     logic of the shape and type. The infer_value() is used for constant propagation.

     More on how to customize a Op, please refer to `Custom Operators
-    <https://www.mindspore.cn/
+    <https://www.mindspore.cn/docs/en/master/model_train/custom_program/op_custom.html>`_.

     Args:
         name (str): Name of the current Primitive.
mindspore/ops_generate/gen_aclnn_implement.py CHANGED

@@ -23,6 +23,7 @@ import pathlib
 import logging
 import gen_utils
 from pyboost_utils import AclnnUtils, get_dtypes
+from gen_constants import MS_OPS_KERNEL_PATH

 auto_gen = ''

@@ -35,7 +36,7 @@ def gen_h(op_name, aclnn_name, op_yaml, kernelmod_h_path, need_update_shape):
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_{op_name.upper()}_ACLNN{auto_gen.upper()}_KERNEL_MOD_H_
 #include <vector>
 #include "ops/base_operator.h"
-#include "
+#include "{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_kernel_mod.h"
 #include "transform/acl_ir/acl_convert.h"
 """
     update_shape = f"""
@@ -78,7 +79,7 @@ def gen_cc(op_name, class_name, op_yaml, kernelmod_cc_path, need_update_shape):
     """generate cc files"""
     kernelmod_name = op_yaml.get('dispatch').get("Ascend")
     cc_head = f"""
-#include "
+#include "{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn{auto_gen}/{op_name}_aclnn_kernel.h"
 #include <algorithm>
 #include <vector>
 #include <memory>
@@ -135,8 +136,7 @@ bool {kernelmod_name}::Launch(const std::vector<KernelTensor *> &inputs, const s
                                const std::vector<KernelTensor *> &outputs, void *stream_ptr) {{
   MS_EXCEPTION_IF_NULL(stream_ptr);
   {input_templete}
-
-  RunOp(stream_ptr, workspace);
+  RunOp(stream_ptr, workspace, {inputs});
   return true;
 }}
 """
@@ -177,10 +177,10 @@ def gen_aclnn_kernel(op_name, yaml_str, need_update_shape=False, auto=False):
     if check_op_registed(op_name) and not auto:
         logging.warning("Kernel {%s} is already registered.", op_name)
         return
-    current_path = os.path.dirname(os.path.
+    current_path = os.path.dirname(os.path.realpath(__file__))
     work_path = os.path.join(current_path, '../../../../')

-    aclnn_path = '
+    aclnn_path = f'{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn/'
     # merge inner ops
     op_yaml = yaml_str.get(op_name)
     class_name = ''.join(word.capitalize() for word in op_name.split('_'))
@@ -196,7 +196,7 @@ def gen_aclnn_kernel(op_name, yaml_str, need_update_shape=False, auto=False):
         return
     auto_gen = "_auto_gen"
     dispatch['Ascend'] = class_name + "Ascend"
-    aclnn_path = '
+    aclnn_path = f'{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn_auto_gen/'
     pathlib.Path(os.path.join(work_path, aclnn_path)).mkdir(parents=True, exist_ok=True)
     if dispatch.get("Ascend") is None:
         raise ValueError("KernelMod {} is auto generated. If need achieve it, "
@@ -208,10 +208,10 @@ def gen_aclnn_kernel(op_name, yaml_str, need_update_shape=False, auto=False):
     generate(op_name, class_name, op_yaml, kernelmod_h_and_cc_path, need_update_shape)


-def get_registed_ops(file_path='
+def get_registed_ops(file_path=f'{MS_OPS_KERNEL_PATH}/ascend/opapi/'):
     '''get registered ops by search files'''
-    # default search in '
-    current_path = os.path.dirname(os.path.
+    # default search in 'ops/kernel/ascend/opapi/'
+    current_path = os.path.dirname(os.path.realpath(__file__))
     work_path = os.path.join(current_path, '../../../../')
     search_path = os.path.join(work_path, file_path)
     ret = []
@@ -230,7 +230,7 @@ def get_registed_ops(file_path='mindspore/ccsrc/plugin/device/ascend/kernel/opap


 registed_ops = get_registed_ops()
-manual_registed_ops = get_registed_ops('
+manual_registed_ops = get_registed_ops(f'{MS_OPS_KERNEL_PATH}/ascend/opapi/aclnn/')


 def check_op_registed(op_name, manual=False):
mindspore/ops_generate/gen_constants.py ADDED

@@ -0,0 +1,36 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Generate constants
+"""
+import os
+
+# op_def
+OP_DEF_AUTO_GENERATE_PATH = "op_def/auto_generate"
+MS_OP_DEF_AUTO_GENERATE_PATH = "mindspore/ops/op_def/auto_generate"
+YAML_PATH = "op_def/yaml"
+MS_YAML_PATH = "mindspore/ops/" + YAML_PATH
+PY_AUTO_GEN_PATH = "mindspore/python/mindspore/ops/auto_generate"
+PY_OPS_GEN_PATH = "mindspore/python/mindspore/ops_generate"
+
+# infer
+MS_OPS_FUNC_IMPL_PATH = "mindspore/ops/infer/ops_func_impl"
+
+# view
+MS_OPS_VIEW_PATH = "mindspore/ops/view"
+
+# kernel
+MS_OPS_KERNEL_PATH = "mindspore/ops/kernel"
+MS_COMMON_PYBOOST_KERNEL_PATH = os.path.join(MS_OPS_KERNEL_PATH, "common/pyboost")