mindspore 2.2.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_akg/akg/composite/build_module.py +104 -20
  3. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  4. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  5. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  6. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  7. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  8. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  9. mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
  10. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  11. mindspore/_akg/akg/utils/kernel_exec.py +41 -15
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
  13. mindspore/_akg/akg/utils/util.py +56 -1
  14. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +3 -3
  17. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  18. mindspore/_extends/graph_kernel/splitter.py +3 -2
  19. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  20. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  22. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  23. mindspore/_extends/parse/__init__.py +3 -2
  24. mindspore/_extends/parse/parser.py +6 -1
  25. mindspore/_extends/parse/standard_method.py +14 -11
  26. mindspore/_extends/remote/kernel_build_server.py +2 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  28. mindspore/bin/cache_admin +0 -0
  29. mindspore/bin/cache_server +0 -0
  30. mindspore/common/_utils.py +16 -0
  31. mindspore/common/api.py +1 -1
  32. mindspore/common/auto_dynamic_shape.py +81 -85
  33. mindspore/common/dump.py +1 -1
  34. mindspore/common/tensor.py +3 -20
  35. mindspore/config/op_info.config +1 -1
  36. mindspore/context.py +11 -4
  37. mindspore/dataset/engine/cache_client.py +8 -5
  38. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  39. mindspore/dataset/vision/transforms.py +21 -21
  40. mindspore/experimental/optim/adam.py +1 -1
  41. mindspore/gen_ops.py +1 -1
  42. mindspore/include/api/model.h +17 -0
  43. mindspore/include/api/status.h +8 -3
  44. mindspore/lib/libdnnl.so.2 +0 -0
  45. mindspore/lib/libmindspore.so +0 -0
  46. mindspore/lib/libmindspore_backend.so +0 -0
  47. mindspore/lib/libmindspore_common.so +0 -0
  48. mindspore/lib/libmindspore_core.so +0 -0
  49. mindspore/lib/libmindspore_glog.so.0 +0 -0
  50. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  51. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  52. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  53. mindspore/lib/libmindspore_shared_lib.so +0 -0
  54. mindspore/lib/libnnacl.so +0 -0
  55. mindspore/lib/libopencv_core.so.4.5 +0 -0
  56. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  57. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  58. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  59. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  60. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  61. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  62. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  63. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  64. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  65. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  66. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  67. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  68. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  71. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/libakg.so +0 -0
  76. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  77. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  79. mindspore/lib/plugin/cpu/libakg.so +0 -0
  80. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  81. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  82. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  83. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  84. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  85. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  88. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  91. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  92. mindspore/nn/cell.py +0 -3
  93. mindspore/nn/layer/activation.py +4 -5
  94. mindspore/nn/layer/conv.py +39 -23
  95. mindspore/nn/layer/flash_attention.py +54 -129
  96. mindspore/nn/layer/math.py +3 -7
  97. mindspore/nn/layer/rnn_cells.py +5 -5
  98. mindspore/nn/wrap/__init__.py +4 -2
  99. mindspore/nn/wrap/cell_wrapper.py +12 -3
  100. mindspore/numpy/utils_const.py +5 -5
  101. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  102. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  103. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  104. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  105. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  106. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  107. mindspore/ops/_utils/utils.py +2 -0
  108. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  109. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  110. mindspore/ops/function/array_func.py +10 -7
  111. mindspore/ops/function/grad/grad_func.py +0 -1
  112. mindspore/ops/function/nn_func.py +98 -9
  113. mindspore/ops/function/random_func.py +2 -1
  114. mindspore/ops/op_info_register.py +24 -21
  115. mindspore/ops/operations/__init__.py +6 -2
  116. mindspore/ops/operations/_grad_ops.py +25 -6
  117. mindspore/ops/operations/_inner_ops.py +155 -23
  118. mindspore/ops/operations/array_ops.py +9 -7
  119. mindspore/ops/operations/comm_ops.py +2 -2
  120. mindspore/ops/operations/custom_ops.py +85 -68
  121. mindspore/ops/operations/inner_ops.py +26 -3
  122. mindspore/ops/operations/math_ops.py +7 -6
  123. mindspore/ops/operations/nn_ops.py +193 -49
  124. mindspore/parallel/_parallel_serialization.py +10 -3
  125. mindspore/parallel/_tensor.py +4 -1
  126. mindspore/parallel/checkpoint_transform.py +13 -2
  127. mindspore/parallel/shard.py +17 -10
  128. mindspore/profiler/common/util.py +1 -0
  129. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  130. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  131. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  132. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  133. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  134. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  135. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  136. mindspore/profiler/parser/framework_parser.py +1 -1
  137. mindspore/profiler/parser/profiler_info.py +19 -0
  138. mindspore/profiler/profiling.py +46 -24
  139. mindspore/rewrite/api/pattern_engine.py +1 -1
  140. mindspore/rewrite/parsers/for_parser.py +7 -7
  141. mindspore/rewrite/parsers/module_parser.py +4 -4
  142. mindspore/rewrite/symbol_tree.py +1 -4
  143. mindspore/run_check/_check_version.py +5 -3
  144. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  145. mindspore/scipy/ops.py +55 -5
  146. mindspore/scipy/optimize/__init__.py +3 -2
  147. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  148. mindspore/train/callback/_summary_collector.py +1 -1
  149. mindspore/train/dataset_helper.py +1 -0
  150. mindspore/train/model.py +2 -2
  151. mindspore/train/serialization.py +97 -11
  152. mindspore/train/summary/_summary_adapter.py +1 -1
  153. mindspore/train/summary/summary_record.py +23 -7
  154. mindspore/version.py +1 -1
  155. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  156. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
  157. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  158. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  159. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  160. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  161. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  162. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  163. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  164. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  165. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  166. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  167. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  168. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  169. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  170. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -26,7 +26,7 @@ from mindspore.ops.operations._scalar_ops import bit_or, bit_and
  from mindspore.ops.operations.comm_ops import ReduceOp
  from mindspore.ops import signature as sig
  from mindspore.ops.operations.math_ops import _infer_shape_reduce
- from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive,\
+ from mindspore.ops.primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register, Primitive, \
  _run_op, _check_contains_variable
  from mindspore._c_expression import Tensor as Tensor_
  from mindspore._c_expression import typing
@@ -167,6 +167,7 @@ class Quant(PrimitiveWithInfer):
  self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
  self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
  "round_mode", self.name)
+ self.add_prim_attr("dst_type", mstype.int8)

  def infer_shape(self, x_shape):
  return x_shape
@@ -174,7 +175,7 @@ class Quant(PrimitiveWithInfer):
  def infer_dtype(self, x_type):
  validator.check_subclass("input_x", x_type, mstype.tensor_type, self.name)
  validator.check_type_name("input_x", x_type, [mstype.float16, mstype.float32], self.name)
- return mstype.int8
+ return self.get_attr_dict()['dst_type']


  class Lamb(PrimitiveWithInfer):
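
Note: the two Quant hunks above move the output dtype from a hard-coded `mstype.int8` into the primitive's attribute dictionary, so `infer_dtype` now reads back whatever was registered under `dst_type`. The snippet below is only a minimal, self-contained sketch of that round trip; the classes are hypothetical stand-ins, not the MindSpore implementation.

    # Hypothetical stand-ins that mirror the add_prim_attr / get_attr_dict round trip.
    class _AttrPrim:
        def __init__(self):
            self._attrs = {}

        def add_prim_attr(self, name, value):
            # register an attribute once, at construction time
            self._attrs[name] = value

        def get_attr_dict(self):
            return self._attrs

    class _QuantLike(_AttrPrim):
        def __init__(self, dst_type="int8"):
            super().__init__()
            self.add_prim_attr("dst_type", dst_type)   # mirrors the added __init__ line

        def infer_dtype(self, x_type):
            # mirrors the new return: read the registered attribute instead of a constant
            return self.get_attr_dict()["dst_type"]

    assert _QuantLike().infer_dtype("float16") == "int8"
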
@@ -491,7 +492,7 @@ class Receive(PrimitiveWithInfer):
  self.dtype = dtype
  self.group = group
  self.add_prim_attr("no_eliminate", True)
- valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
+ valid_type = [mstype.float16, mstype.bfloat16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
  args = {"dtype": dtype}
  validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)

@@ -2146,13 +2147,14 @@ class ClipByNorm(PrimitiveWithInfer):
  @prim_attr_register
  def __init__(self, axis=None):
  """Initialize ClipByNorm"""
+ self.axis_str = 'axis'
  self.axis = () if axis is None else axis
- validator.check_value_type('axis', self.axis, [int, tuple, list], self.name)
+ validator.check_value_type(self.axis_str, self.axis, [int, tuple, list], self.name)
  axis_check = self.axis if isinstance(self.axis, Iterable) else (self.axis,)
  for i, value in enumerate(axis_check):
  validator.check_value_type('axis[%d]' % i, value, [int], self.name)
- self.init_attrs['axis'] = self.axis
- self.add_prim_attr('axis', self.axis)
+ self.init_attrs[self.axis_str] = self.axis
+ self.add_prim_attr(self.axis_str, self.axis)
  self.init_prim_io_names(inputs=['x', 'clip_norm'], outputs=['output'])

  def infer_shape(self, x_shape, clip_norm_shape):
@@ -2729,27 +2731,29 @@ class CopyWithSlice(Primitive):
  self.init_prim_io_names(inputs=['x', 'y'], outputs=['x'])


- class MoeFFN(Primitive):
+ class FFN(Primitive):
  r"""
- The MoeFFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.
+ The FFN computation is similar to Feed-Forward Network, it contains matmul + gelu + matmul.

  Args:
  activation (string): The activation type, set to 'fastgelu' or 'gelu'.
- Only support 'fastgelu' for now. Default: "fastgelu".
+ Only support 'fastgelu' for now. Default: "fastgelu".
+ inner_precise (int): The precise mode, set to 0 for high precision or 1 for high performance.
+ Only support 1 for now. Default: 0.

  Inputs:
  - **x** (Tensor) - The input tensor with data type of int8, float16.
  Input tensor of shape :math:`(batch\_size * seq\_length, hidden\_size)`.
+ - **weight1** (Tensor) - The weight1 tensor with data type of float16.
+ Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
+ - **weight2** (Tensor) - The weight2 tensor with data type of float16.
+ Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
  - **expert_tokens** (Tensor) - The expert tokens tensor with data type of int64.
  Expert tokens tensor of shape :math:`(16,)`. For example, `(2, 1, 0, .., 9)`
  indicates that the 0th expert deals with 2 tokens, the 1st expert deals with 1 token,
  the 2nd expert does nothing, and so on.
- - **weight1** (Tensor) - The weight1 tensor with data type of float16.
- Weight1 tensor of shape :math:`(expert\_num, hidden\_size, ffn\_hidden\_size)`.
  - **bias1** (Tensor) - The bias1 tensor with data type of float16.
  Bias1 tensor of shape :math:`(expert\_num, ffn\_hidden\_size)`.
- - **weight2** (Tensor) - The weight2 tensor with data type of float16.
- Weight2 tensor of shape :math:`(expert\_num, ffn\_hidden\_size, hidden\_size)`.
  - **bias2** (Tensor) - The bias2 tensor with data type of float16.
  Bias2 tensor of shape :math:`(expert\_num, hidden\_size)`.
  - **scale** (Tensor) - The scale tensor with data type of float16. Not enabled now.
@@ -2771,21 +2775,149 @@ class MoeFFN(Primitive):
  >>> h_f = 4 * h
  >>> e = 16
  >>> x = Tensor(np.random.randn(b * s, h).astype(np.float16))
- >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
  >>> w1 = Tensor(np.random.randn(e, h, h_f).astype(np.float16))
- >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
  >>> w2 = Tensor(np.random.randn(e, h_f, h).astype(np.float16))
+ >>> expert_tokens = Tensor(np.random.randn(e).astype(np.int64))
+ >>> bias1 = Tensor(np.random.randn(e, h_f).astype(np.float16))
  >>> bias2 = Tensor(np.random.randn(e, h).astype(np.float16))
- >>> moe_ffn = _inner_ops.MoeFFN("fastgelu")
- >>> output = moe_ffn(x, w1, bias1, w2, bias2)
+ >>> ffn = _inner_ops.FFN("fastgelu", 1)
+ >>> output = ffn(x, w1, w2, expert_tokens, bias1, bias2)
  >>> print(output)
  """

  @prim_attr_register
- def __init__(self, activation):
- """Initialize MoeFFN."""
- self.init_prim_io_names(inputs=["x", "expert_tokens", "weight1", "bias1",
- "weight2", "bias2", "scale", "offset", "deq_scale1"
- "deq_scale2"],
+ def __init__(self, activation, inner_precise):
+ """Initialize FFN."""
+ self.init_prim_io_names(inputs=["x", "weight1", "weight2", "expert_tokens", "bias1",
+ "bias2", "scale", "offset", "deq_scale1", "deq_scale2"],
  outputs=["y"])
- self.activation = activation
+ cls_name = self.name
+ validator.check_value_type("activation", activation, [str], cls_name)
+ validator.check_value_type("inner_precise", inner_precise, [int], cls_name)
+
+
+ class DecoderKVCache(Primitive):
+ r"""
+ The DecoderKVCache is used for decoding the KVCache of a transformer network.
+
+ Args:
+ cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+ When seq_len_axis is 2, cache tensor of shape
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+ When seq_len_axis is 1, cache tensor of shape
+ :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+ update (Tensor): The tensor which is used to update the cache tensor. Same data type as the cache tensor.
+ When seq_len_axis is 2, update tensor of shape
+ :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+ When seq_len_axis is 1, update tensor of shape
+ :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+ valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+ Valid_seq_len tensor of shape :math:`(batch\_size)`.
+ batch_index (Tensor): The batch_index tensor with data type of int64.
+ Batch_index tensor of shape :math:`(1)`. Indicates which batch of the cache tensor is going to be updated.
+ seq_len_axis (int64): The seq_len_axis indicates which axis is seq_len, set to '1' or '2'. Default: "2".
+ new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+ New_max_seq_len tensor of shape :math:`(1)`.
+ Indicates that the user wants to change the shape of the cache tensor from
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+ :math:
+ `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+ to update the cache tensor. This will not really change the shape of the `cache` tensor. Not enabled for now.
+ cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64.
+ Cur_max_seq_len tensor of shape :math:`(1)`. Keeps the current seq_len of the cache tensor. Not enabled for now.
+
+ Outputs:
+ With the same data type and the same shape as the `cache` tensor.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> from mindspore.ops.operations import _inner_ops
+ >>> b = 4
+ >>> h = 40
+ >>> max_s = 1024
+ >>> s = 1
+ >>> d = 128
+ >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+ >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+ >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+ >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+ >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+ >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+ >>> decoder_kv_cache = _inner_ops.DecoderKVCache()
+ >>> output = decoder_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+ >>> print(cache)
+ """
+ @prim_attr_register
+ def __init__(self):
+ """Initialize DecoderKVCache."""
+ self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+ "new_max_seq_len", "cur_max_seq_len"],
+ outputs=["out"])
+ self.add_prim_attr('side_effect_mem', True)
+
+
+ class PromptKVCache(Primitive):
+ r"""
+ The PromptKVCache is used for prefilling the KVCache of a transformer network.
+
+ Args:
+ cache (Tensor): The cache tensor with data type of int8, uint8, int16, uint16, float16, float32 and int32.
+ When seq_len_axis is 2, cache tensor of shape
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)`.
+ When seq_len_axis is 1, cache tensor of shape
+ :math:`(batch\_size, max\_seq\_length, num_head, hidden\_size)`.
+ update (Tensor): The tensor which is used to update the cache tensor. Same data type as the cache tensor.
+ When seq_len_axis is 2, update tensor of shape
+ :math:`(batch\_size, num_head, update\_seq\_length, hidden\_size)`.
+ When seq_len_axis is 1, update tensor of shape
+ :math:`(batch\_size, update\_seq\_length, num_head, hidden\_size)`.
+ valid_seq_len (Tensor): The valid_seq_len tensor with data type of int64.
+ Valid_seq_len tensor of shape :math:`(batch\_size)`.
+ batch_index (Tensor): The batch_index tensor with data type of int64.
+ Batch_index tensor of shape :math:`(1)`. Indicates which batch of the cache tensor is going to be updated.
+ seq_len_axis (int64): The seq_len_axis indicates which axis is seq_len, set to '1' or '2'. Default: "2".
+ new_max_seq_len (Tensor): The new_max_seq_len tensor with data type of int64.
+ New_max_seq_len tensor of shape :math:`(1)`.
+ Indicates that the user wants to change the shape of the cache tensor from
+ :math:`(batch\_size, num_head, max\_seq\_length, hidden\_size)` to
+ :math:
+ `(batch\_size * max\_seq\_length / new\_max\_seq\_length, num_head, new\_max\_seq\_length, hidden\_size)`
+ to update the cache tensor. This will not really change the shape of the `cache` tensor. Not enabled for now.
+ cur_max_seq_len (Tensor): The cur_max_seq_len tensor with data type of int64.
+ Cur_max_seq_len tensor of shape :math:`(1)`. Keeps the current seq_len of the cache tensor. Not enabled for now.
+ align_mode (int64): The alignment mode of the `update` tensor, 0 is 'right', 1 is 'left'. Default: 0.
+
+ Outputs:
+ With the same data type and the same shape as the `cache` tensor.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> from mindspore import Tensor
+ >>> from mindspore.ops.operations import _inner_ops
+ >>> b = 4
+ >>> h = 40
+ >>> max_s = 1024
+ >>> s = 256
+ >>> d = 128
+ >>> cache = Tensor(np.random.randn(b, h, max_s, d).astype(np.float16))
+ >>> update = Tensor(np.random.randn(b, h, s, d).astype(np.float16))
+ >>> valid_seq_len = Tensor(np.random.randn(b).astype(np.int64))
+ >>> batch_index = Tensor(np.random.randn(1).astype(np.int64))
+ >>> new_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+ >>> cur_max_seq_len = Tensor(np.random.randn(1).astype(np.int64))
+ >>> prompt_kv_cache = _inner_ops.PromptKVCache(0)
+ >>> output = prompt_kv_cache(cache, update, valid_seq_len, batch_index, 2, new_max_seq_len, cur_max_seq_len)
+ >>> print(cache)
+ """
+ @prim_attr_register
+ def __init__(self, padding_mode="right"):
+ """Initialize PromptKVCache."""
+ self.init_prim_io_names(inputs=["cache", "update", "valid_seq_len", "batch_index", "seq_len_axis",
+ "new_max_seq_len", "cur_max_seq_len"],
+ outputs=["out"])
+ self.add_prim_attr('side_effect_mem', True)
+ self.padding_mode = padding_mode
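
Note: both new primitives above describe an in-place write of new key/value states into a pre-allocated cache. The NumPy sketch below is one plausible reading of the documented decode-step semantics (seq_len_axis == 2, one token per batch written at the index given by valid_seq_len); it is an illustration only, not the Ascend kernel, and the meaning assigned to valid_seq_len here is an assumption.

    import numpy as np

    # Shapes mirror the DecoderKVCache example above.
    b, h, max_s, s, d = 4, 40, 1024, 1, 128
    cache = np.zeros((b, h, max_s, d), dtype=np.float16)
    update = np.random.randn(b, h, s, d).astype(np.float16)
    # assumed meaning: current length of each sequence, i.e. the next write index
    valid_seq_len = np.array([0, 3, 7, 2], dtype=np.int64)

    for i in range(b):                       # per-batch in-place update of the cache
        pos = int(valid_seq_len[i])
        cache[i, :, pos:pos + s, :] = update[i]
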
@@ -1208,7 +1208,7 @@ class UniqueWithPad(Primitive):


  class Split(Primitive):
- """
+ r"""
  Splits the input tensor into output_num of tensors along the given axis and output numbers.

  Refer to :func:`mindspore.ops.split` for more details.
@@ -1222,7 +1222,7 @@ class Split(Primitive):

  Outputs:
  tuple[Tensor], the shape of each output tensor is the same, which is
- :math:`(x_0, x_1, ..., x_{axis}/{output_num}, ..., x_{R-1})`.
+ :math:`(x_0, x_1, ..., x_{axis}/{output\_num}, ..., x_{R-1})`.
  And the data type is the same as `input_x`.

  Supported Platforms:
@@ -1763,16 +1763,18 @@ class FillV2(PrimitiveWithCheck):
  self.init_prim_io_names(inputs=['shape', 'value'], outputs=['y'])

  def check_elim(self, dims, x):
- if x is None or (not isinstance(x, (Tensor, Tensor_))) or (x.shape != ()) or\
- dims is None or (isinstance(dims, (tuple, list)) and dims) or\
- isinstance(dims, (Tensor, Tensor_)):
+ x_is_invalid = x is None or (not isinstance(x, (Tensor, Tensor_))) or (x.shape != ())
+ dims_is_invalid = dims is None or (isinstance(dims, (tuple, list)) and dims) or\
+ isinstance(dims, (Tensor, Tensor_))
+ if x_is_invalid or dims_is_invalid:
  return (False, None)
  return (True, x)

  def infer_value(self, dims, x):
- if x is None or dims is None or\
+ dims_is_invalid = dims is None or\
  (isinstance(dims, (tuple, list)) and dims) or\
- isinstance(dims, (Tensor, Tensor_)):
+ isinstance(dims, (Tensor, Tensor_))
+ if x is None or dims_is_invalid:
  return None
  return x

@@ -94,7 +94,7 @@ class ReduceOp:

  def check_collective_target_dtype(data_name, data_dtype, prim_name):
  """Check if data type is valid."""
- default_target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)
+ default_target_dtypes = (mstype.int8, mstype.uint8, mstype.int32, mstype.float16, mstype.bfloat16, mstype.float32)
  gpu_target_dtypes = (mstype.bool_, mstype.int8, mstype.int32, mstype.int64, mstype.uint32, mstype.uint64,
  mstype.float16, mstype.float32, mstype.float64)

@@ -1310,4 +1310,4 @@ class _GetTensorSlice(PrimitiveWithInfer):
  from mindspore.parallel._tensor import _load_tensor
  validator.check_value_type("dev_mat", dev_mat, [tuple], self.name)
  validator.check_value_type("tensor_map", tensor_map, [tuple], self.name)
- return Tensor(_load_tensor(x, dev_mat, tensor_map))
+ return Tensor(_load_tensor(x, dev_mat, tensor_map), x.dtype)
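
Note: the one-line change above passes the original parameter's dtype when the sliced tensor is rebuilt, so the slice cannot silently drift to a default dtype. A NumPy-only illustration of the difference (the helper is hypothetical and only mimics the two constructor behaviours):

    import numpy as np

    def rebuild_slice(x, keep_dtype):
        sliced = x[:2]                                  # stand-in for _load_tensor(...)
        if keep_dtype:
            return np.asarray(sliced, dtype=x.dtype)    # analogous to Tensor(..., x.dtype)
        return np.asarray(sliced, dtype=np.float32)     # analogous to an implicit default

    x = np.arange(4, dtype=np.float16)
    print(rebuild_slice(x, keep_dtype=False).dtype)     # float32: dtype drifted
    print(rebuild_slice(x, keep_dtype=True).dtype)      # float16: dtype preserved
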
@@ -42,6 +42,24 @@ from ._pyfunc_registry import add_pyfunc
  if platform.system() != "Windows":
  import fcntl

+ KEY_ATTR = "attr"
+ KEY_NAME = "name"
+ INPUT_NAMES = "input_names"
+ ATTR_NAMES = "attr_names"
+ AUTO_DIFF = "autodiff"
+ IMPLY_TYPE = "imply_type"
+ FUSION_TYPE = "fusion_type"
+ MS_KERNEL_FLAG = "ms_kernel_flag"
+ AKG = "AKG"
+ TBE = "TBE"
+ CUDA = "CUDA"
+ AICORE = "AiCore"
+ CPU = "CPU"
+ GPU = "GPU"
+ ASCEND = "Ascend"
+ HYBRID_TYPE = "hybrid"
+ OP_NAME = "op_name"
+

  def _get_cache_path():
  """
@@ -150,7 +168,6 @@ class Custom(ops.PrimitiveWithInfer):

  .. warning::
  - This is an experimental API that is subject to change.
- - Currently, the functionality of Custom does not support Ascend 910B.

  .. note::
  The supported platforms are determined by the input `func_type`. The supported platforms are as follows:
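
Note: the warning block above belongs to ops.Custom, whose constructor signature appears in the next hunk. A minimal pyfunc-style usage sketch is shown below, adapted from the documented interface; treat it as an approximation rather than a verified 2.2.11 snippet, since argument forms can differ between versions.

    import numpy as np
    from mindspore import ops, Tensor

    def add_np(a, b):
        # runs as a native Python function on CPU when func_type="pyfunc"
        return a + b

    # out_shape / out_dtype are callables mapping input shapes / dtypes to the output's
    custom_add = ops.Custom(add_np, out_shape=lambda a, b: a,
                            out_dtype=lambda a, b: a, func_type="pyfunc")

    x = Tensor(np.ones((2, 2), np.float32))
    y = Tensor(np.ones((2, 2), np.float32))
    print(custom_add(x, y))
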
@@ -453,10 +470,10 @@ class Custom(ops.PrimitiveWithInfer):
  op_path_in_cache = [] # Save paths for op functions created in the cached.
  custom_aot_warning = True # Flag to enable warnings about custom aot path white list

- def __init__(self, func, out_shape=None, out_dtype=None, func_type="hybrid", bprop=None, reg_info=None):
- ops.PrimitiveWithInfer.__init__(self, "Custom")
+ def __init__(self, func, out_shape=None, out_dtype=None, func_type=HYBRID_TYPE, bprop=None, reg_info=None):
+ super().__init__("Custom")

- self.supported_targets = ["Ascend", "GPU", "CPU"]
+ self.supported_targets = [ASCEND, GPU, CPU]
  self.supported_func_type = ["hybrid", "akg", "tbe", "aicpu", "aot", "pyfunc", "julia"]
  self.log_prefix = "For '{}', 'func_type': {}, 'func': {}".format(self.name, func_type, func)
  self.func = func
@@ -473,7 +490,7 @@ class Custom(ops.PrimitiveWithInfer):
  self._update_func_info(reg_info)
  self.add_prim_attr("func_name", self.func_name)
  self.add_prim_attr("uniq_name", self.uniq_name)
- if self.func_type == "hybrid":
+ if self.func_type == HYBRID_TYPE:
  self.add_prim_attr("func_compile_attrs", self._func_compile_attrs)

  self.add_prim_attr("imply_path", self.imply_path)
@@ -502,7 +519,7 @@ class Custom(ops.PrimitiveWithInfer):
  if func_type == "akg":
  self._set_akg_kernel_type()

- if not self.bprop and self.func_type == "hybrid":
+ if not self.bprop and self.func_type == HYBRID_TYPE:
  self._hybrid_autodiff(func_type)

  self.add_prim_attr("func_type", self.func_type)
@@ -577,7 +594,7 @@ class Custom(ops.PrimitiveWithInfer):
  elif "compute" in self.func_source_str:
  self.func_type = "tvm_compute"
  else:
- self.func_type = "hybrid"
+ self.func_type = HYBRID_TYPE
  self._hybrid_func_analyser()

  def _check_julia_func(self):
@@ -633,18 +650,18 @@ class Custom(ops.PrimitiveWithInfer):

  elif self.func_type == "julia":
  self._check_julia_func()
- elif self.func_type == "hybrid":
- if not hasattr(self.func, "ms_kernel_flag"):
+ elif self.func_type == HYBRID_TYPE:
+ if not hasattr(self.func, MS_KERNEL_FLAG):
  raise TypeError("{}, 'func' must be a function decorated by kernel".format(self.log_prefix))
  self._is_ms_kernel = True
  self._func_compile_attrs = getattr(self.func, "compile_attrs", {})
  elif self.func_type == "akg":
- if hasattr(self.func, "ms_kernel_flag"):
+ if hasattr(self.func, MS_KERNEL_FLAG):
  logger.warning("{}. To have a better user experience, the mode hybrid is suggested "
  "for the input function with decorator @kernel. "
  "To enable this mode, set the 'func_type' to be \"hybrid\"".format(self.log_prefix))
  elif self.func_type == "pyfunc":
- if hasattr(self.func, "ms_kernel_flag"):
+ if hasattr(self.func, MS_KERNEL_FLAG):
  logger.warning("{}. Now you are using the function with decorator @kernel in the mode pyfunc. "
  "The kernel will be executed as a native python function, which might lead to "
  "low efficiency. To accelerate the kernel, set the 'func_type' to be \"hybrid\""
@@ -758,7 +775,7 @@ class Custom(ops.PrimitiveWithInfer):
  continue
  if isinstance(reg_info_item, str):
  reg_info_item = json.loads(reg_info_item)
- prefix = "_".join([prefix, reg_info_item.get("op_name", "")])
+ prefix = "_".join([prefix, reg_info_item.get(OP_NAME, "")])
  self.uniq_name = prefix + "_" + self.func_name
  else:
  raise TypeError("For '{}', 'func' must be of type function or str, but got {}"
@@ -768,23 +785,23 @@ class Custom(ops.PrimitiveWithInfer):
  """Update op attrs in reg_info."""
  output_name_list = []
  for _, item in enumerate(reg_info.get("outputs", [])):
- if isinstance(item, dict) and item.get("name"):
- output_name_list.append(item.get("name"))
+ if isinstance(item, dict) and item.get(KEY_NAME):
+ output_name_list.append(item.get(KEY_NAME))
  if output_name_list:
  self.add_prim_attr("output_names", output_name_list)

- if isinstance(reg_info.get("op_name"), str):
- self.add_prim_attr("reg_op_name", reg_info.get("op_name"))
+ if isinstance(reg_info.get(OP_NAME), str):
+ self.add_prim_attr("reg_op_name", reg_info.get(OP_NAME))

  if self.func_type == "aicpu":
- self.uniq_name = reg_info["op_name"]
+ self.uniq_name = reg_info[OP_NAME]
  self.add_prim_attr("uniq_name", self.uniq_name)

  if self.func_type in ["aot", "aicpu"]:
- if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
- for item in reg_info["attr"]:
+ if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+ for item in reg_info[KEY_ATTR]:
  if isinstance(item, dict) and item.get("value") is not None:
- self.add_prim_attr(item["name"], item["value"])
+ self.add_prim_attr(item[KEY_NAME], item["value"])

  def _register_info(self, info):
  """Register reg_info."""
@@ -802,7 +819,7 @@ class Custom(ops.PrimitiveWithInfer):
  if isinstance(reg_info, str):
  reg_info = json.loads(reg_info)
  if self.fake_output:
- reg_info["outputs"].append(dict({"index": 0, "name": "y", "param_type": "required"}))
+ reg_info["outputs"].append(dict({"index": 0, KEY_NAME: "y", "param_type": "required"}))
  new_dtype_format = []
  for i in reg_info["dtype_format"]:
  new_dtype_format.append(i + (DataType.I32_Default,))
@@ -874,16 +891,16 @@ class Custom(ops.PrimitiveWithInfer):
  "'CustomRegOp' to generate the registration information, then pass it to 'reg_info' or "
  "use 'custom_info_register' to bind it to 'func' if 'func' is a function."
  .format(self.log_prefix, reg_info, type(reg_info)))
- reg_info["op_name"] = self.uniq_name
- reg_info["imply_type"] = self._get_imply_type(reg_info, target)
- if not isinstance(reg_info.get("fusion_type"), str) or not reg_info["fusion_type"].strip():
- reg_info["fusion_type"] = "OPAQUE"
+ reg_info[OP_NAME] = self.uniq_name
+ reg_info[IMPLY_TYPE] = self._get_imply_type(reg_info, target)
+ if not isinstance(reg_info.get(FUSION_TYPE), str) or not reg_info[FUSION_TYPE].strip():
+ reg_info[FUSION_TYPE] = "OPAQUE"
  # Supplement necessary info for TBE if these information is missing in reg_info
- if reg_info["imply_type"] == "TBE":
- if reg_info.get("attr") is not None and isinstance(reg_info["attr"], list):
- for i, item in enumerate(reg_info["attr"]):
+ if reg_info[IMPLY_TYPE] == TBE:
+ if reg_info.get(KEY_ATTR) is not None and isinstance(reg_info[KEY_ATTR], list):
+ for i, item in enumerate(reg_info[KEY_ATTR]):
  if isinstance(item, dict) and item.get("value") is None:
- reg_info["attr"][i]["value"] = "all"
+ reg_info[KEY_ATTR][i]["value"] = "all"
  reg_info["async_flag"] = reg_info.get("async_flag", False)
  reg_info["binfile"] = "%s.so" % self.func_name
  reg_info["compute_cost"] = reg_info.get("compute_cost", 10)
@@ -891,8 +908,8 @@ class Custom(ops.PrimitiveWithInfer):
  reg_info["partial_flag"] = reg_info.get("partial_flag", True)
  reg_info["needCheckSupport"] = reg_info.get("need_check_supported", False)
  # Supplement necessary info for AKG if these information is missing in reg_info
- if reg_info["imply_type"] == "AKG":
- target_to_processor = {"Ascend": "AiCore", "GPU": "CUDA", "CPU": "CPU"}
+ if reg_info[IMPLY_TYPE] == AKG:
+ target_to_processor = {ASCEND: AICORE, GPU: CUDA, CPU: CPU}
  reg_info["processor"] = reg_info.get("processor", target_to_processor.get(target))
  return reg_info

@@ -905,15 +922,15 @@ class Custom(ops.PrimitiveWithInfer):
  # Infer target from reg_info["processor"], reg_info generated from AkgGpuRegOp or AkgAscendRegOp
  # will have the processor information.
  if target not in self.supported_targets:
- processor_to_target = {"AiCore": "Ascend", "CUDA": "GPU", "CPU": "CPU"}
+ processor_to_target = {AICORE: ASCEND, CUDA: GPU, CPU: CPU}
  target = processor_to_target.get(reg_info.get("processor"))
- # Infer target from reg_info["imply_type"]
+ # Infer target from reg_info[IMPLY_TYPE]
  if target not in self.supported_targets:
- imply_type_to_target = {"TBE": "Ascend", "GPU": "GPU", "CPU": "CPU"}
- target = imply_type_to_target.get(reg_info.get("imply_type"))
+ imply_type_to_target = {TBE: ASCEND, GPU: GPU, CPU: CPU}
+ target = imply_type_to_target.get(reg_info.get(IMPLY_TYPE))
  # Infer target from func_type
  if target not in self.supported_targets:
- func_type_to_target = {"tbe": "Ascend", "pyfunc": "CPU"}
+ func_type_to_target = {"tbe": ASCEND, "pyfunc": CPU}
  target = func_type_to_target.get(self.func_type)
  if target not in self.supported_targets:
  raise ValueError("{}, target set in registration information must be one of {}, but got {}"
@@ -922,14 +939,14 @@ class Custom(ops.PrimitiveWithInfer):

  def _get_imply_type(self, reg_info, target):
  """Get imply_typ information."""
- # Get imply_type from reg_info["imply_type"]
- if isinstance(reg_info, dict) and isinstance(reg_info.get("imply_type"), str) and \
- reg_info["imply_type"].strip():
- return reg_info["imply_type"]
+ # Get imply_type from reg_info[IMPLY_TYPE]
+ if isinstance(reg_info, dict) and isinstance(reg_info.get(IMPLY_TYPE), str) and \
+ reg_info[IMPLY_TYPE].strip():
+ return reg_info[IMPLY_TYPE]
  # Infer imply_type from func_type
- func_type_to_imply_type = {"hybrid": "AKG", "akg": "AKG", "tbe": "TBE", "aicpu": "AiCPU", "pyfunc": target,
- "julia": target, "aot": "BiSheng" if target == "Ascend" else target}
- return func_type_to_imply_type.get(self.func_type, "AKG")
+ func_type_to_imply_type = {"hybrid": AKG, "akg": AKG, "tbe": TBE, "aicpu": "AiCPU", "pyfunc": target,
+ "julia": target, "aot": "BiSheng" if target == ASCEND else target}
+ return func_type_to_imply_type.get(self.func_type, AKG)

  def _save_attr(self, reg_info):
  """Save input_names and attr_names of current func."""
@@ -943,18 +960,18 @@ class Custom(ops.PrimitiveWithInfer):
  return value

  tensor_inputs = _get_value_list("inputs")
- attr = _get_value_list("attr")
+ attr = _get_value_list(KEY_ATTR)
  input_names = [] # include tensor input names and attr input names
  attr_names = []
  pure_input_names = []
  for item in tensor_inputs:
- if isinstance(item, dict) and item.get("name") is not None:
- input_names.append(item["name"])
- pure_input_names.append(item["name"])
+ if isinstance(item, dict) and item.get(KEY_NAME) is not None:
+ input_names.append(item[KEY_NAME])
+ pure_input_names.append(item[KEY_NAME])
  # attr is converted from inputs only when graph mode or when inputs name is also in reg info
  attr_to_input_safe = bool(input_names) or context.get_context("mode") == ms.GRAPH_MODE
  for item in attr:
- if isinstance(item, dict) and item.get("name") is not None:
+ if isinstance(item, dict) and item.get(KEY_NAME) is not None:
  # for custom op with function tbe, we always add attrs to inputs as we don't
  # deal with attr value here and leave them to the backend process to fit the
  # usual process of tbe op compiling in mindspore
@@ -963,9 +980,9 @@ class Custom(ops.PrimitiveWithInfer):
  # add attr name to input name only when the value of attr is None in reg info
  # as we need to get values of attrs from inputs
  if attr_to_input_safe and (self.func_type == "tbe" or item.get("value", None) is None):
- input_names.append(item["name"])
- attr_names.append(item["name"])
- cur_attr = {"input_names": input_names, "attr_names": attr_names, "pure_input_names": pure_input_names}
+ input_names.append(item[KEY_NAME])
+ attr_names.append(item[KEY_NAME])
+ cur_attr = {INPUT_NAMES: input_names, ATTR_NAMES: attr_names, "pure_input_names": pure_input_names}
  # If func does not have attr, save current attr.
  # Else, check if current attr is same as previous saved one.
  prev_attr_names = attr_names
@@ -974,13 +991,13 @@ class Custom(ops.PrimitiveWithInfer):
  if not isinstance(func_attr, dict):
  setattr(self.func, "func_attr", cur_attr)
  else:
- prev_attr_names = func_attr.get("attr_names")
+ prev_attr_names = func_attr.get(ATTR_NAMES)
  elif isinstance(self.func, str):
  func_attr = Custom.attr_dict.get(self.func)
  if not isinstance(func_attr, dict):
  Custom.attr_dict[self.func] = cur_attr
  else:
- prev_attr_names = func_attr.get("attr_names")
+ prev_attr_names = func_attr.get(ATTR_NAMES)
  if attr_names != prev_attr_names:
  raise ValueError("{}, attr names set in registration information must be the same as previous saved one, "
  "but got {} vs {}".format(self.log_prefix, attr_names, prev_attr_names))
@@ -989,23 +1006,23 @@ class Custom(ops.PrimitiveWithInfer):
  """Add primitive_target to primitive's attr."""
  registered_targets = self._get_registered_targets()
  if self.func_type == "pyfunc":
- self.set_device("CPU")
- if registered_targets and registered_targets != ["CPU"]:
+ self.set_device(CPU)
+ if registered_targets and registered_targets != [CPU]:
  logger.warning("{}, only supports CPU platform, but got registered target {}. "
  "We will run it on CPU".format(self.log_prefix, registered_targets))
  elif self.func_type == "aot":
  if len(registered_targets) != 1:
  logger.info("{}, target will be set according to context.".format(self.log_prefix))
- elif registered_targets == ["GPU"]:
- self.set_device("GPU")
- elif registered_targets == ["CPU"]:
- self.set_device("CPU")
+ elif registered_targets == [GPU]:
+ self.set_device(GPU)
+ elif registered_targets == [CPU]:
+ self.set_device(CPU)
  elif self.func_type == "julia":
- self.set_device("CPU")
+ self.set_device(CPU)
  device_target = context.get_context('device_target')
- if device_target == "CPU":
+ if device_target == CPU:
  pass
- elif device_target == "GPU" and registered_targets and registered_targets == ["CPU"]:
+ elif device_target == GPU and registered_targets and registered_targets == [CPU]:
  logger.warning("{}, only supports CPU platform, but got registered target {}. "
  "We will run it on CPU".format(self.log_prefix, registered_targets))
  else:
@@ -1028,15 +1045,15 @@ class Custom(ops.PrimitiveWithInfer):
  elif isinstance(self.func, str):
  func_attr = Custom.attr_dict.get(self.func)
  if isinstance(func_attr, dict):
- _add_prim_attr("input_names")
- _add_prim_attr("attr_names")
+ _add_prim_attr(INPUT_NAMES)
+ _add_prim_attr(ATTR_NAMES)
  _add_prim_attr("pure_input_names")
  self._add_prim_target()
  if callable(self.func) and callable(self.out_shape):
- if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == "autodiff":
- self.add_prim_attr("autodiff", True)
+ if hasattr(self.out_shape, "type") and getattr(self.out_shape, "type") == AUTO_DIFF:
+ self.add_prim_attr(AUTO_DIFF, True)
  else:
- self.add_prim_attr("autodiff", False)
+ self.add_prim_attr(AUTO_DIFF, False)

  def _hybrid_autodiff(self, input_func_type):
  """generate backward op for a custom hybrid op"""
@@ -1052,7 +1069,7 @@ class Custom(ops.PrimitiveWithInfer):
  def infer_func(*args):
  return args[:inputs_num]

- setattr(infer_func, "type", "autodiff")
+ setattr(infer_func, "type", AUTO_DIFF)
  op = Custom(func=self.func, out_shape=infer_func, out_dtype=infer_func,
  func_type=input_func_type, bprop=True)
  self.bprop = grad_func(op)