mindspore 2.2.0-cp37-cp37m-manylinux1_x86_64.whl → 2.2.11-cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. mindspore/.commit_id +1 -1
  2. mindspore/_akg/akg/composite/build_module.py +104 -20
  3. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  4. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  5. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  6. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  7. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  8. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  9. mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
  10. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  11. mindspore/_akg/akg/utils/kernel_exec.py +41 -15
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
  13. mindspore/_akg/akg/utils/util.py +56 -1
  14. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +3 -3
  17. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  18. mindspore/_extends/graph_kernel/splitter.py +3 -2
  19. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  20. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  22. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  23. mindspore/_extends/parse/__init__.py +3 -2
  24. mindspore/_extends/parse/parser.py +6 -1
  25. mindspore/_extends/parse/standard_method.py +14 -11
  26. mindspore/_extends/remote/kernel_build_server.py +2 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  28. mindspore/bin/cache_admin +0 -0
  29. mindspore/bin/cache_server +0 -0
  30. mindspore/common/_utils.py +16 -0
  31. mindspore/common/api.py +1 -1
  32. mindspore/common/auto_dynamic_shape.py +81 -85
  33. mindspore/common/dump.py +1 -1
  34. mindspore/common/tensor.py +3 -20
  35. mindspore/config/op_info.config +1 -1
  36. mindspore/context.py +11 -4
  37. mindspore/dataset/engine/cache_client.py +8 -5
  38. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  39. mindspore/dataset/vision/transforms.py +21 -21
  40. mindspore/experimental/optim/adam.py +1 -1
  41. mindspore/gen_ops.py +1 -1
  42. mindspore/include/api/model.h +17 -0
  43. mindspore/include/api/status.h +8 -3
  44. mindspore/lib/libdnnl.so.2 +0 -0
  45. mindspore/lib/libmindspore.so +0 -0
  46. mindspore/lib/libmindspore_backend.so +0 -0
  47. mindspore/lib/libmindspore_common.so +0 -0
  48. mindspore/lib/libmindspore_core.so +0 -0
  49. mindspore/lib/libmindspore_glog.so.0 +0 -0
  50. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  51. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  52. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  53. mindspore/lib/libmindspore_shared_lib.so +0 -0
  54. mindspore/lib/libnnacl.so +0 -0
  55. mindspore/lib/libopencv_core.so.4.5 +0 -0
  56. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  57. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  58. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  59. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  60. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  61. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  62. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  63. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  64. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  65. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  66. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  67. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  68. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  71. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/libakg.so +0 -0
  76. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  77. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  79. mindspore/lib/plugin/cpu/libakg.so +0 -0
  80. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  81. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  82. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  83. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  84. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  85. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  88. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  91. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  92. mindspore/nn/cell.py +0 -3
  93. mindspore/nn/layer/activation.py +4 -5
  94. mindspore/nn/layer/conv.py +39 -23
  95. mindspore/nn/layer/flash_attention.py +54 -129
  96. mindspore/nn/layer/math.py +3 -7
  97. mindspore/nn/layer/rnn_cells.py +5 -5
  98. mindspore/nn/wrap/__init__.py +4 -2
  99. mindspore/nn/wrap/cell_wrapper.py +12 -3
  100. mindspore/numpy/utils_const.py +5 -5
  101. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  102. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  103. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  104. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  105. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  106. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  107. mindspore/ops/_utils/utils.py +2 -0
  108. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  109. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  110. mindspore/ops/function/array_func.py +10 -7
  111. mindspore/ops/function/grad/grad_func.py +0 -1
  112. mindspore/ops/function/nn_func.py +98 -9
  113. mindspore/ops/function/random_func.py +2 -1
  114. mindspore/ops/op_info_register.py +24 -21
  115. mindspore/ops/operations/__init__.py +6 -2
  116. mindspore/ops/operations/_grad_ops.py +25 -6
  117. mindspore/ops/operations/_inner_ops.py +155 -23
  118. mindspore/ops/operations/array_ops.py +9 -7
  119. mindspore/ops/operations/comm_ops.py +2 -2
  120. mindspore/ops/operations/custom_ops.py +85 -68
  121. mindspore/ops/operations/inner_ops.py +26 -3
  122. mindspore/ops/operations/math_ops.py +7 -6
  123. mindspore/ops/operations/nn_ops.py +193 -49
  124. mindspore/parallel/_parallel_serialization.py +10 -3
  125. mindspore/parallel/_tensor.py +4 -1
  126. mindspore/parallel/checkpoint_transform.py +13 -2
  127. mindspore/parallel/shard.py +17 -10
  128. mindspore/profiler/common/util.py +1 -0
  129. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  130. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  131. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  132. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  133. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  134. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  135. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  136. mindspore/profiler/parser/framework_parser.py +1 -1
  137. mindspore/profiler/parser/profiler_info.py +19 -0
  138. mindspore/profiler/profiling.py +46 -24
  139. mindspore/rewrite/api/pattern_engine.py +1 -1
  140. mindspore/rewrite/parsers/for_parser.py +7 -7
  141. mindspore/rewrite/parsers/module_parser.py +4 -4
  142. mindspore/rewrite/symbol_tree.py +1 -4
  143. mindspore/run_check/_check_version.py +5 -3
  144. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  145. mindspore/scipy/ops.py +55 -5
  146. mindspore/scipy/optimize/__init__.py +3 -2
  147. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  148. mindspore/train/callback/_summary_collector.py +1 -1
  149. mindspore/train/dataset_helper.py +1 -0
  150. mindspore/train/model.py +2 -2
  151. mindspore/train/serialization.py +97 -11
  152. mindspore/train/summary/_summary_adapter.py +1 -1
  153. mindspore/train/summary/summary_record.py +23 -7
  154. mindspore/version.py +1 -1
  155. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  156. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
  157. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  158. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  159. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  160. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  161. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  162. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  163. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  164. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  165. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  166. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  167. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  168. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  169. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  170. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/flash_attention.py

@@ -21,9 +21,7 @@ import mindspore.common.dtype as mstype
  from mindspore.common.tensor import Tensor
  from mindspore import ops
  from mindspore.nn.cell import Cell
- from mindspore.ops._op_impl._custom_op.flash_attention.flash_attention_impl import get_flash_attention
  from mindspore.ops.operations.nn_ops import FlashAttentionScore
- from mindspore._c_expression import MSContext

  __all__ = ['FlashAttention']

@@ -46,25 +44,25 @@ class FlashAttention(Cell):
  Default 65536.
  next_block_num(int): A integer to define the number of blocks to look behind for local block sparse attention.
  Default 65536.
- tiling_stgy_name(str): A str to define tiling strategy of flash attention.
  dp(int): data parallel.
  Default 1.
  mp(int): model parallel.
  Default 1.
- high_precision(bool): This mode has higher precision but some performance loss.
+ high_precision(bool): This mode has higher precision but some performance loss. Only take effect on Ascend910A.
  Default False.
  have_attention_mask_batch(bool): indicates whether attention_mask contains the batch dimension.
  Default True
  alibi(bool): This parameter indicates whether the flashattention supports the Alibi.
  Default: False
+ use_mqa(bool): Using MQA if True, only take effect under 910B. Default: False.


  Inputs:
  - **query** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
  - **key** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
  - **value** (Tensor) - Tensor value (:class:`mstype.fp16` [batch_size, head_num, seq_length, head_dim])
- - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` [batch_size, seq_length,
- seq_length]): A matrix to pass masked information.
+ - **attention_mask** (Tensor) - Float Tensor the mask of (:class:`mstype.fp16` `mstype.uint8`
+ [batch_size, seq_length, seq_length]): A matrix to pass masked information.

  Outputs:
  A Tensor. The output of the attention with shape [batch_size, head_num, seq_length, head_dim]
@@ -97,56 +95,51 @@ class FlashAttention(Cell):
  dropout_rate=0.0,
  prev_block_num=65536,
  next_block_num=65536,
- tiling_stgy_name="sparse",
  dp=1,
  mp=1,
  high_precision=False,
  have_attention_mask_batch=True,
- alibi=False
+ alibi=False,
+ use_mqa=False
  ):
  super(FlashAttention, self).__init__()

  scaling_constant = math.sqrt(head_dim)
  if scaling_constant == 0:
  raise ValueError("the scaling constant must not be 0.")
- self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
+ self.dropout_rate = dropout_rate
+ self.alibi = alibi
+ self.have_attention_mask_batch = have_attention_mask_batch

- self.is_910A = MSContext.get_instance().get_ascend_soc_version() == "Ascend910"
- if self.is_910A:
- self.flash_attention = get_flash_attention(
- prev_block_num=prev_block_num,
- next_block_num=next_block_num,
- tiling_stgy_name=tiling_stgy_name,
- high_precision=high_precision
- )
- self.flash_attention.add_prim_attr("primitive_target", "Ascend")
- else:
- if alibi:
- raise ValueError(f"When soc_version is not Ascend910A, alibi must be False")
- self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
- self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
- self.reshape = ops.Reshape()
- self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
- self.zeros = ops.Zeros()
- self.attn_expand_dims = ops.ExpandDims().shard(((dp, 1, 1),))
- fa_strategies = ((dp, 1, mp),
- (dp, 1, mp),
- (dp, 1, mp),
+ self.transpose_4d_pre = ops.Transpose().shard(((dp, mp, 1, 1),))
+ self.transpose_4d_post = ops.Transpose().shard(((dp, 1, mp, 1),))
+ self.reshape = ops.Reshape()
+ self.zeros_like = ops.ZerosLike().shard(((dp, mp, 1, 1),))
+ self.zeros = ops.Zeros()
+ self.attn_cast = ops.Cast()
+ if use_mqa:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, 1, 1, 1),
  (dp, 1, 1, 1))
- if dropout_rate > 1e-5:
- fa_strategies += ((dp, mp, 1, 1),)
- self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
- next_tokens=next_block_num,
- keep_prob=1 - dropout_rate,
- scale_value=1.0,
- inner_precise=0 if high_precision else 1).shard(fa_strategies)
+ else:
+ fa_strategies = ((dp, mp, 1, 1),
+ (dp, mp, 1, 1),
+ (dp, mp, 1, 1))
+ if self.alibi:
+ self.alibi_rescale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
+ self.alibi_rescale_factor = Tensor([scaling_constant], dtype=mstype.float16)
+ fa_strategies += ((dp, mp, 1, 1),)
+ if dropout_rate > 1e-5:
+ fa_strategies += ((dp, mp, 1, 1),)
+ fa_strategies += ((dp, 1, 1, 1),)
+ self.flash_attention = FlashAttentionScore(head_num=head_num, pre_tokens=prev_block_num,
+ next_tokens=next_block_num,
+ keep_prob=1 - dropout_rate,
+ scale_value=1. / scaling_constant,
+ inner_precise=0,
+ input_layout="BNSD").shard(fa_strategies)

- self.ones = ops.Ones()
- self.dim_mask = Tensor([1 for _ in range(head_dim)], dtype=mstype.int8)
- self.scale_mul = ops.Mul().shard(((dp, mp, 1, 1), (1,)))
  self.dropout_rate = dropout_rate
- self.have_attention_mask_batch = have_attention_mask_batch
- self.alibi = alibi
  if self.dropout_rate > 1e-5:
  self.keep_prob = Tensor(1 - self.dropout_rate, dtype=mstype.float16)
  self.fill_v2 = ops.FillV2().shard(((dp, mp, 1, 1), ()))
@@ -162,46 +155,7 @@ class FlashAttention(Cell):
  such as MatMul. Default: None.
  :return:
  """
- if in_strategy is None:
- # default: dp=1, mp=1, construct inputs only contain query, key, value
- in_strategy = (
- (1, 1, 1, 1),
- (1, 1, 1, 1),
- (1, 1, 1, 1),
- )
  self.flash_attention.shard(in_strategy)
- dp = in_strategy[0][0]
- mp = in_strategy[0][1]
- self.flash_attention.add_prim_attr("dev_matrix_shape", [dp, mp, 1, 1])
- inputs_tensor_map = [
- [3, 2, 1, 0],
- [3, 2, 1, 0],
- [3, 2, 1, 0],
- ]
- if self.have_attention_mask_batch:
- inputs_tensor_map.append([3, 1, 0])
- else:
- inputs_tensor_map.append([-1, 1, 0])
-
- input_empty_args_num = 2
- # dropout_mask
- if self.dropout_rate > 1e-5:
- input_empty_args_num -= 1
- inputs_tensor_map.append([3, 2, 1, 0])
-
- if self.alibi:
- input_empty_args_num -= 1
- inputs_tensor_map.append([3, 2, 1, 0])
-
- self.flash_attention.add_prim_attr("inputs_tensor_map", inputs_tensor_map)
-
- self.flash_attention.add_prim_attr("outputs_tensor_map", [
- [3, 2, 1, 0], # O
- [3, 2, 1], # L
- [3, 2, 1] # M
- ])
- self.flash_attention.add_prim_attr("as_loss_divisor", 0)
- self.flash_attention.add_prim_attr("empty_mirror_ops", input_empty_args_num)

  def construct(self, query, key, value, attn_mask=None, alibi_mask=None):
  """FlashAttention forward
@@ -212,53 +166,24 @@ class FlashAttention(Cell):
  :param alibi_mask: [bsz, head_num, 1, seq_len], if not None
  :return: output [bsz, head_num, seq_len, head_dim]
  """
- query = self.scale_mul(query, self.scale_factor)
- bsz, head_num, seq_len, head_dim = query.shape
- _, k_head_num, k_seq_len, _ = key.shape
- _, v_head_num, v_seq_len, _ = value.shape
- if head_num != k_head_num or head_num != v_head_num:
- raise ValueError(
- "the head_num of query, key and value must be the same, "
- "If different head_num are used, users need to change themselves to be same by tile.")
- if seq_len % 16 != 0 or k_seq_len % 16 != 0 or k_seq_len != v_seq_len:
- raise ValueError(
- "query, key, value seq_len must be a multiple of 16, and key seq_len, value seq_len must be the same.")
-
- if head_dim > 304:
- raise ValueError(
- "the head_dim must be less than 304, otherwise the ub would be OOM.")
-
- if self.is_910A:
- # 910A -- FlashAttentionPrimtive
- if self.dropout_rate > 1e-5:
- drop_mask_bits = self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob)
- tensor_shape = Tensor((bsz, head_num, seq_len, seq_len), mstype.int32)
- ones = self.fill_v2(tensor_shape, self.tensor_one)
- ones = self.depend(ones, query)
- drop_mask = self.do_dropout(ones, drop_mask_bits, self.keep_prob)
- else:
- drop_mask = None
- output, _, _ = self.flash_attention(query, key, value, attn_mask, drop_mask, alibi_mask)
+ bsz, head_num, seq_len, _ = query.shape
+ # 910B -- FlashAttentionScore
+ if self.dropout_rate > 1e-5:
+ drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
+ (bsz, head_num, seq_len, seq_len // 8))
  else:
- # FlashAttentionScore
- # Useless input, just for binary calls.
- if self.dropout_rate > 1e-5:
- drop_mask_bits = self.reshape(self.drop_gen_mask((bsz, head_num, seq_len, seq_len), self.keep_prob),
- (bsz, head_num, seq_len, seq_len // 8))
- else:
- drop_mask_bits = None
- # (B, N, S, D) -> (B, S, H)
- query = self.reshape(self.transpose_4d_pre(query, (0, 2, 1, 3)), (bsz, seq_len, -1))
- key = self.reshape(self.transpose_4d_pre(key, (0, 2, 1, 3)), (bsz, seq_len, -1))
- value = self.reshape(self.transpose_4d_pre(value, (0, 2, 1, 3)), (bsz, seq_len, -1))
- attn_mask = self.attn_expand_dims(attn_mask, 1)
- output, _, _ = self.flash_attention(query,
- key,
- value,
- attn_mask,
- drop_mask_bits,
- None,
- None)
- output = self.transpose_4d_post(self.reshape(output, (bsz, seq_len, head_num, head_dim)), (0, 2, 1, 3))
-
+ drop_mask_bits = None
+ if self.alibi:
+ alibi_mask = self.alibi_rescale_mul(alibi_mask, self.cast(self.alibi_rescale_factor, alibi_mask.dtype))
+ # (B, S, S) -> (B, 1, S, S)
+ if self.have_attention_mask_batch:
+ attn_mask = self.cast(self.reshape(attn_mask, (bsz, 1, seq_len, seq_len)), mstype.uint8)
+ _, _, _, output = self.flash_attention(query,
+ key,
+ value,
+ alibi_mask,
+ drop_mask_bits,
+ None,
+ attn_mask,
+ None)
  return output
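With the 910A code path removed, the layer above always lowers to the fused FlashAttentionScore primitive in BNSD layout, so callers pass fp16 (batch, head_num, seq, head_dim) tensors plus an optional (batch, seq, seq) mask. A minimal call sketch under that reading, assuming head_num/head_dim are the leading constructor arguments (they sit above the visible hunk context) and that an Ascend 910B backend is available:

    import numpy as np
    import mindspore as ms
    from mindspore.nn.layer.flash_attention import FlashAttention

    bsz, heads, seq, dim = 2, 16, 1024, 128
    q = ms.Tensor(np.random.randn(bsz, heads, seq, dim), ms.float16)
    k = ms.Tensor(np.random.randn(bsz, heads, seq, dim), ms.float16)
    v = ms.Tensor(np.random.randn(bsz, heads, seq, dim), ms.float16)
    # Causal mask in (batch, seq, seq); the cell reshapes it to (batch, 1, seq, seq) and casts to uint8.
    mask = ms.Tensor(np.triu(np.ones((bsz, seq, seq)), k=1), ms.float16)

    fa = FlashAttention(head_dim=dim, head_num=heads, dropout_rate=0.0)  # keyword names assumed from the hunk
    out = fa(q, k, v, mask)  # -> (bsz, heads, seq, dim), fp16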
mindspore/nn/layer/math.py

@@ -375,9 +375,6 @@ class DiGamma(Cell):
  nan, real_result)


- eps_fp32 = Tensor(np.finfo(np.float32).eps, mstype.float32)
-
-
  def _while_helper_func(cond, body, vals):
  while cond(vals).any():
  vals = body(vals)
@@ -394,7 +391,7 @@ def _igamma_series(ax, x, a, enabled):
  select = P.Select()

  # If more data types are supported, this epsilon need to be selected.
- epsilon = eps_fp32
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

  def cond(vals):
  enabled = vals[0]
@@ -443,7 +440,7 @@ def _igammac_continued_fraction(ax, x, a, enabled):
  select = P.Select()

  # If more data types are supported, this epsilon need to be selected.
- epsilon = eps_fp32
+ epsilon = Tensor(np.finfo(np.float32).eps, mstype.float32)

  def cond(vals):
  enabled = vals[0]
@@ -620,8 +617,7 @@ class IGamma(Cell):
  x = F.broadcast_to(x, para_shape)
  a = F.broadcast_to(a, para_shape)
  x_is_zero = self.equal(x, 0)
- log_maxfloat = self.log_maxfloat32
- underflow = self.less(ax, self.neg(log_maxfloat))
+ underflow = self.less(ax, self.neg(self.log_maxfloat32))
  ax = self.exp(ax)
  enabled = self.logicalnot(self.logicalor(self.logicalor(x_is_zero, domain_error), underflow))
  output = self.select(use_igammac,

mindspore/nn/layer/rnn_cells.py

@@ -83,7 +83,7 @@ def _check_lstmcell_init(func):


  def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''RNN cell function with tanh activation'''
+ """RNN cell function with tanh activation"""
  if b_ih is None:
  igates = P.MatMul(False, True)(inputs, w_ih)
  hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -94,7 +94,7 @@ def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''RNN cell function with relu activation'''
+ """RNN cell function with relu activation"""
  if b_ih is None:
  igates = P.MatMul(False, True)(inputs, w_ih)
  hgates = P.MatMul(False, True)(hidden, w_hh)
@@ -105,7 +105,7 @@ def _rnn_relu_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''LSTM cell function'''
+ """LSTM cell function"""
  hx, cx = hidden
  if b_ih is None:
  gates = P.MatMul(False, True)(inputs, w_ih) + P.MatMul(False, True)(hx, w_hh)
@@ -125,7 +125,7 @@ def _lstm_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
- '''GRU cell function'''
+ """GRU cell function"""
  if b_ih is None:
  gi = P.MatMul(False, True)(inputs, w_ih)
  gh = P.MatMul(False, True)(hidden, w_hh)
@@ -144,7 +144,7 @@ def _gru_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):


  class RNNCellBase(Cell):
- '''Basic class for RNN Cells'''
+ """Basic class for RNN Cells"""
  def __init__(self, input_size: int, hidden_size: int, has_bias: bool, num_chunks: int,
  dtype=mstype.float32):
  super().__init__()
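The five hunks above only switch the helper docstrings from single to double quotes; for readers skimming the diff, the helpers themselves compute the textbook recurrences. Stated for the LSTM case (standard definitions, not something this release changes), with input $x_t$, previous state $(h_{t-1}, c_{t-1})$ and stacked weights $W_{ih}, W_{hh}$:

    i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}), \quad
    f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}),
    g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}), \quad
    o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}),
    c_t = f_t \odot c_{t-1} + i_t \odot g_t, \qquad h_t = o_t \odot \tanh(c_t).

The GRU helper follows the analogous reset/update-gate recurrence, and the two plain RNN helpers apply tanh or relu to a single gate.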
mindspore/nn/wrap/__init__.py

@@ -20,7 +20,8 @@ Use the Wrapper to combine the loss or build the training steps.
  from __future__ import absolute_import

  from mindspore.nn.wrap.cell_wrapper import ForwardValueAndGrad, TrainOneStepCell, WithLossCell, WithGradCell, \
- WithEvalCell, ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, MicroBatchInterleaved, PipelineCell
+ WithEvalCell, ParameterUpdate, GetNextSingleOp, VirtualDatasetCellTriple, MicroBatchInterleaved, PipelineCell, \
+ GradAccumulationCell
  from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell,\
  DynamicLossScaleUpdateCell, FixedLossScaleUpdateCell
  from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
@@ -42,5 +43,6 @@ __all__ = [
  "ParameterUpdate",
  "DynamicLossScaleUpdateCell",
  "FixedLossScaleUpdateCell",
- "VirtualDatasetCellTriple"
+ "VirtualDatasetCellTriple",
+ "GradAccumulationCell"
  ]
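GradAccumulationCell, newly exported above and documented in the cell_wrapper.py hunks below, slots into a training script roughly as follows. This is a hedged sketch: the loss/optimizer plumbing is generic, MyNetwork is a hypothetical user-defined Cell, and wrapping the loss-wrapped net (mirroring how PipelineCell is usually applied) is an assumption of the sketch rather than something the diff states.

    import mindspore as ms
    from mindspore import nn

    ms.set_context(mode=ms.GRAPH_MODE)
    # Gradient accumulation via this wrapper is meaningful under (semi_)auto_parallel, per the new docstring.
    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    net = MyNetwork()                                   # hypothetical user-defined Cell
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_with_loss = nn.WithLossCell(net, loss_fn)
    # Split each step into 4 micro-batches whose gradients are accumulated.
    net_with_loss = nn.GradAccumulationCell(net_with_loss, 4)
    optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
    train_step = nn.TrainOneStepCell(net_with_loss, optimizer)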
mindspore/nn/wrap/cell_wrapper.py

@@ -644,6 +644,9 @@ class PipelineCell(Cell):
  self.micro_inputs = nn.CellList()
  self.micro_size = micro_size
  self.add_list = []
+ if not isinstance(network, Cell):
+ raise TypeError("For 'PipelineCell', the argument 'network' must cell type, "
+ "but got the type : {}.".format(type(network)))
  if not isinstance(micro_size, int):
  raise TypeError("For 'PipelineCell', the argument 'micro_size' must be integer, "
  "but got the type : {}.".format(type(micro_size)))
@@ -670,7 +673,7 @@

  class GradAccumulationCell(Cell):
  """
- Wrap the network with Micro Batch.
+ Wrap the network with Micro Batch to enable the grad accumulation in semi_auto_parallel/auto_parallel mode.

  Args:
  network (Cell): The target network to wrap.
@@ -680,8 +683,11 @@ class GradAccumulationCell(Cell):
  ``Ascend`` ``GPU``

  Examples:
- >>> net = Net()
- >>> net = GradAccumulationCell(net, 4)
+ >>> import mindspore.nn as nn
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
+ >>> net = LeNet5()
+ >>> net = nn.GradAccumulationCell(net, 4)
  """
  def __init__(self, network, micro_size):
  super(GradAccumulationCell, self).__init__(auto_prefix=False)
@@ -689,6 +695,9 @@ class GradAccumulationCell(Cell):
  self.micro_inputs = nn.CellList()
  self.micro_size = micro_size
  self.add_list = []
+ if not isinstance(network, Cell):
+ raise TypeError("For 'GradAccumulationCell', the argument 'network' must cell type, "
+ "but got the type : {}.".format(type(network)))
  if not isinstance(micro_size, int):
  raise TypeError("For 'GradAccumulationCell', the argument 'micro_size' must be integer, "
  "but got the type : {}.".format(type(micro_size)))
mindspore/numpy/utils_const.py

@@ -143,8 +143,8 @@ def _infer_out_shape(*shapes):
  shape_out = list()
  max_len = max([len(it) for it in shapes])
  for i in range(max_len):
- items = [it[i-max_len+len(it)] if i-max_len +
- len(it) >= 0 else 1 for it in shapes]
+ items = [
+ it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
  max_size = 0 if 0 in items else max(items)
  _check()
  shape_out.append(max_size)
@@ -158,8 +158,8 @@ def _can_broadcast(*shapes):
  """
  max_len = max([len(it) for it in shapes])
  for i in range(max_len):
- items = [it[i-max_len+len(it)] if i-max_len +
- len(it) >= 0 else 1 for it in shapes]
+ items = [
+ it[i - max_len + len(it)] if i - max_len + len(it) >= 0 else 1 for it in shapes]
  max_size = 0 if 0 in items else max(items)
  if any(item not in (1, max_size) for item in items):
  return False
@@ -399,7 +399,7 @@ def _broadcast_tuples(tup1, tup2):
  if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)):
  raise TypeError("input shift and axis must be tuple or list or int.")
  if len(tup1) == len(tup2) or len(tup1) == 1 or len(tup2) == 1:
- return None
+ return
  raise ValueError("shape mismatch: objects cannot be broadcast to a single shape")

  tup1 = (tup1,) if isinstance(tup1, int) else tup1
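The _infer_out_shape and _can_broadcast helpers reformatted above both walk the shapes right to left, treating missing leading dimensions as 1, which is ordinary NumPy broadcasting. A standalone plain-Python sketch of the same rule (not the graph-compiled `_primexpr` versions):

    def infer_broadcast_shape(*shapes):
        """Right-align the shapes, treat missing leading dims as 1, take the max per column."""
        max_len = max(len(s) for s in shapes)
        out = []
        for i in range(max_len):
            items = [s[i - max_len + len(s)] if i - max_len + len(s) >= 0 else 1 for s in shapes]
            max_size = 0 if 0 in items else max(items)
            if any(item not in (1, max_size) for item in items):
                raise ValueError(f"shapes {shapes} cannot be broadcast together")
            out.append(max_size)
        return tuple(out)

    print(infer_broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)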
mindspore/ops/_grad_experimental/grad_array_ops.py

@@ -203,7 +203,7 @@ def get_bprop_index_put(self):
  if is_ascend:
  indices_ms = [convert_idx_positive(indices_ms[i], x1.shape[i]) for i in range(len(indices_ms))]
  indices_me = stack(indices_ms)
- indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me)-1, -1, -1))
+ indices_grad = F.transpose(indices_me, F.make_range(F.rank(indices_me) - 1, -1, -1))
  values_grad = gather_nd(dout, indices_grad)
  if equal(cast(x2.shape[0], mstype.int32), Tensor(1)):
  values_grad = values_grad.sum().reshape(1)

mindspore/ops/_grad_experimental/grad_implementations.py

@@ -19,7 +19,7 @@ from mindspore.ops import functional as F
  from mindspore.ops import operations as P
  from mindspore.ops.composite import multitype_ops as C
  from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
- from mindspore.ops._grad_experimental.grad_base import bprops
+ from mindspore.ops._grad_experimental.grad_base import bprops, bprop_getters
  from mindspore.common import dtype as mstype

  get_dtype = P.DType()
@@ -193,7 +193,7 @@ def bprop_tensor_move(x, out, dout):
  return (dout,)


- @bprops.register("DictInplaceSetItem")
+ @bprop_getters.register("DictInplaceSetItem")
  def get_bprop_dict_inplace_setitem(self):
  """Generate bprop for dict inplace pop"""


mindspore/ops/_grad_experimental/grad_math_ops.py

@@ -135,7 +135,7 @@ def get_bprop_matrix_triangular_solve(self):

  def bprop(matrix, rhs, out, dout):
  grad_rhs = matrix_triangular_solve_op(matrix, dout)
- if matrix.dtype == mstype.complex64 or matrix.dtype == mstype.complex128:
+ if matrix.dtype in (mstype.complex64, mstype.complex128):
  grad_rhs_temp = _adjoint(grad_rhs)
  out_temp = _adjoint(out)
  else:
@@ -156,14 +156,14 @@ def get_bprop_matrix_triangular_solve(self):
  grad_matrix = mat_mul_op(grad_rhs, out_temp)
  grad_matrix = neg_op(grad_matrix)
  if lower_a:
- if grad_matrix.dtype == mstype.complex64 or grad_matrix.dtype == mstype.complex128:
+ if grad_matrix.dtype in (mstype.complex64, mstype.complex128):
  grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), -1, 0)
  grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), -1, 0)
  grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
  else:
  grad_matrix = matrix_band_part_op(grad_matrix, -1, 0)
  else:
- if grad_matrix.dtype == mstype.complex64 or grad_matrix.dtype == mstype.complex128:
+ if grad_matrix.dtype in (mstype.complex64, mstype.complex128):
  grad_matrix_real = matrix_band_part_op(real_op(grad_matrix), 0, -1)
  grad_matrix_imag = matrix_band_part_op(imag_op(grad_matrix), 0, -1)
  grad_matrix = complex_op(grad_matrix_real, grad_matrix_imag)
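For reference, the matrix_triangular_solve bprop tidied up above implements the standard adjoint rule for a triangular solve, stated here in general terms and ignoring the lower/adjoint flag plumbing: if X solves AX = B with A triangular, then

    \bar{B} = A^{-H}\,\bar{X}, \qquad \bar{A} = -\,\mathrm{tri}\!\left(\bar{B}\, X^{H}\right),

where bars denote incoming cotangents, A^{-H} is the inverse conjugate transpose, and tri(.) keeps only the triangle A actually occupies, which is what the matrix_band_part(..., -1, 0) and (0, -1) calls do. For real dtypes the conjugation is a no-op, which is why the complex64/complex128 branch is the only place the real/imag splitting is needed.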
@@ -219,7 +219,7 @@ def get_bprop_matrix_solve(self):
  @_primexpr
  def _generate_perm_matrix_solve_ls(x_dim):
  perm = tuple(range(x_dim - 2))
- perm = perm + (x_dim-1, x_dim-2)
+ perm = perm + (x_dim - 1, x_dim - 2)
  return perm


@@ -647,20 +647,21 @@ def _fft_rank_offset(norm_shape, rank):
  @_primexpr
  def _fft_with_size_back_norm(norm_shape, norm, inverse, rank):
  """generate reverse term for fft_with_size"""
+ norm_ = None
  if inverse is False:
  if norm == "forward":
- norm_ = 1 / _fft_rank_offset(norm_shape, rank)
- if norm == "backward":
- norm_ = 1 * _fft_rank_offset(norm_shape, rank)
- if norm == "ortho":
- norm_ = 1
- if inverse is True:
+ norm_ = 1.0 / _fft_rank_offset(norm_shape, rank)
+ elif norm == "backward":
+ norm_ = 1.0 * _fft_rank_offset(norm_shape, rank)
+ elif norm == "ortho":
+ norm_ = 1.0
+ else:
  if norm == "forward":
- norm_ = 1 * _fft_rank_offset(norm_shape, rank)
- if norm == "backward":
- norm_ = 1 / _fft_rank_offset(norm_shape, rank)
- if norm == "ortho":
- norm_ = 1
+ norm_ = 1.0 * _fft_rank_offset(norm_shape, rank)
+ elif norm == "backward":
+ norm_ = 1.0 / _fft_rank_offset(norm_shape, rank)
+ elif norm == "ortho":
+ norm_ = 1.0
  return norm_


@@ -670,9 +671,9 @@ def _rfft_norm(norm_shape, norm, rank):
  norm_ = 1.0
  if norm == "forward":
  norm_ = 1 / _fft_rank_offset(norm_shape, rank)
- if norm == "backward":
- norm_ = 1
- if norm == "ortho":
+ elif norm == "backward":
+ norm_ = 1.0
+ elif norm == "ortho":
  norm_ = 1 / np.sqrt(_fft_rank_offset(norm_shape, rank))
  return norm_

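The rewritten chains above just select a scalar scale for the backward pass of FFTWithSize. A standalone sketch of the same table in plain Python (illustrative only; `n` stands for the product of transformed dimension sizes that `_fft_rank_offset` computes):

    def fft_grad_scale(n, norm, inverse):
        """Backward-pass scale when the forward op used normalization `norm` over n elements."""
        if not inverse:  # gradient of a forward transform
            return {"forward": 1.0 / n, "backward": 1.0 * n, "ortho": 1.0}[norm]
        # gradient of an inverse transform: the roles of forward/backward swap
        return {"forward": 1.0 * n, "backward": 1.0 / n, "ortho": 1.0}[norm]

    fft_grad_scale(8, "forward", inverse=False)  # 0.125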
mindspore/ops/_grad_experimental/grad_sparse_ops.py

@@ -358,10 +358,10 @@ def get_bprop_ragged_tensor_to_sparse(self):
  split.append(zeros_like(i))
  all_d = (split, ragged_values_grad)
  return all_d
- split = ()
+ split_ = ()
  for i in enumerate(rt_nested_splits):
- split = split + (zeros_like(i),)
- all_d = (split, ragged_values_grad)
+ split_ = split_ + (zeros_like(i),)
+ all_d = (split_, ragged_values_grad)
  return all_d

  return bprop

mindspore/ops/_op_impl/aicpu/add.py

@@ -29,9 +29,9 @@ add_op_info = AiCPURegOp("Add") \
  .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
  .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
  .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
- .dtype_format(DataType.U16_Default, DataType.I16_Default, DataType.I16_Default) \
- .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default) \
- .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I64_Default) \
+ .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
+ .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
+ .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
  .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
  .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
  .get_op_info()

mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py

@@ -1,4 +1,4 @@
- # Copyright 2022 Huawei Technologies Co., Ltd
+ # Copyright 2023 Huawei Technologies Co., Ltd
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -24,11 +24,30 @@ lsap_op_info = AiCPURegOp("LinearSumAssignment") \
  .input(2, 'maximize', "required") \
  .output(0, "row_ind", "required") \
  .output(1, "col_ind", "required") \
- .attr("cust_aicpu", "str") \
  .dtype_format(DataType.F64_Default, DataType.I64_Default,
  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
  .dtype_format(DataType.F32_Default, DataType.I64_Default,
  DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.F16_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.BOOL_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.I16_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.I32_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.I64_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.I8_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.U16_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.U32_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.U64_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
+ .dtype_format(DataType.U8_Default, DataType.I64_Default,
+ DataType.BOOL_Default, DataType.I64_Default, DataType.I64_Default,) \
  .get_op_info()


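The extra dtype_format rows above widen the cost-matrix dtypes the AICPU LinearSumAssignment kernel registers; each new row only varies the first (cost-matrix) position, with the remaining positions unchanged. The op targets the same assignment-problem semantics as SciPy, so as a quick reminder of what row_ind/col_ind mean, here is the SciPy reference behaviour (plain SciPy, not the MindSpore primitive):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[4, 1, 3],
                     [2, 0, 5],
                     [3, 2, 2]])
    row_ind, col_ind = linear_sum_assignment(cost, maximize=False)
    print(row_ind, col_ind)               # e.g. [0 1 2] [1 0 2]
    print(cost[row_ind, col_ind].sum())   # minimal total cost: 5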
mindspore/ops/_utils/utils.py

@@ -77,10 +77,12 @@ def get_broadcast_shape(x_shape, y_shape, prim_name, arg_name1="x", arg_name2="y
  broadcast_shape = list(broadcast_shape_front) + broadcast_shape_back
  return broadcast_shape

+
  def dim_not_equal(dim1, dim2):
  """Compare dim in shape"""
  return dim1 != dim2 and dim1 >= 0 and dim2 >= 0

+
  def get_concat_offset(x_shp, x_type, axis, prim_name):
  """for concat and concatoffset check args and compute offset"""
  validator.check_value_type("shape", x_shp, [tuple, list], prim_name)

mindspore/ops/composite/multitype_ops/_compile_utils.py

@@ -1255,7 +1255,8 @@ def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
  index = index.reshape(const_utils.generate_padding_shape(index.shape, len(data.shape)))
  index = F.broadcast_to(index, data.shape)
  value = F.cast(value, F.dtype(data))
- value = value.reshape(const_utils.generate_padding_shape(value.shape, len(data.shape)))
+ while value.ndim < data.ndim:
+ value = value.unsqueeze(-1)
  value = F.broadcast_to(value, data.shape)
  result = F.select(index, value, data)
  return result
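The new code path above, in the helper that handles `tensor[bool_index] = value`, appends trailing length-1 axes to `value` until its rank matches `data` and then broadcasts, rather than reshaping it into a padded shape in one step. A standalone NumPy illustration of those shape mechanics for a per-row assignment (NumPy stands in for the graph ops; shapes are hypothetical):

    import numpy as np

    data = np.zeros((3, 4))
    mask = np.array([True, False, True])            # row mask, broadcast across columns
    value = np.array([10., 20., 30.])               # one value per row

    mask_full = np.broadcast_to(mask[:, None], data.shape)
    value_full = value
    while value_full.ndim < data.ndim:              # the new unsqueeze(-1) loop
        value_full = value_full[..., None]
    value_full = np.broadcast_to(value_full, data.shape)

    result = np.where(mask_full, value_full, data)  # rows 0 and 2 become 10s and 30s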
mindspore/ops/composite/multitype_ops/getitem_impl.py

@@ -161,7 +161,7 @@ def _tuple_getitem_by_slice(data, slice_index):
  if start is None:
  start = 0 if step >= 1 else -1
  if stop is None:
- stop = (2**31-1) if step >= 1 else -(2**31-1)
+ stop = (2**31 - 1) if step >= 1 else -(2**31 - 1)
  return sequence_slice(data, start, stop, step)
  return _tuple_slice(data, slice_index)

@@ -236,7 +236,7 @@ def _list_getitem_by_slice(data, slice_index):
  if start is None:
  start = 0 if step >= 1 else -1
  if stop is None:
- stop = (2**31-1) if step >= 1 else -(2**31-1)
+ stop = (2**31 - 1) if step >= 1 else -(2**31 - 1)
  return sequence_slice(data, start, stop, step)
  return _list_slice(data, slice_index)