optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
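The largest single change in this release is the rewrite of the decoder-only attention path (apparently item 31 above, decoderonly_architecture.py), which registers RBLN custom operators (rbln_custom_ops::flash_attn_decode, flash_attn_prefill, rbln_cache_update) through torch.library. As a reading aid for that diff, the sketch below shows the same registration pattern on a minimal, hypothetical op (demo::add_one is not part of the package): declare the schema, attach a CPU reference kernel, and attach an abstract (fake) kernel for symbolic tracing.

    import torch

    # Declare the op schema under a custom namespace (cf. rbln_custom_ops below).
    torch.library.define("demo::add_one", "(Tensor x) -> Tensor")

    @torch.library.impl("demo::add_one", "cpu")
    def add_one_cpu(x):
        # Eager CPU reference implementation.
        return x + 1

    @torch.library.impl_abstract("demo::add_one")
    def add_one_abstract(x):
        # Shape/dtype propagation only, used during tracing.
        return torch.empty_like(x)

    print(torch.ops.demo.add_one(torch.zeros(4)))  # tensor([1., 1., 1., 1.])

Call sites then go through the dispatcher, e.g. torch.ops.rbln_custom_ops.flash_attn_decode(...) in the hunks below.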
@@ -26,141 +26,241 @@ from typing import Dict, Optional, Tuple

  import torch
  from torch import nn
+ from transformers import PretrainedConfig
  from transformers.modeling_outputs import (
  BaseModelOutputWithPast,
  )

+ from ....utils import logging
  from ...cache_utils import RebelDynamicCache
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+
+ logger = logging.get_logger(__name__)
+ """
+ ##############################################################################
+ # RBLN custom operation (python interface)
+ # torch.compile custom operation
+ # torch.library.define - kernel declaration
+ # torch.library.impl - kernel implementation
+ # torch.library.impl_abstract - symbolic trace
+ ##############################################################################
+ """
+
+ # RBLN custom op(flash attention decode)
+ torch.library.define(
+ "rbln_custom_ops::flash_attn_decode",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+ )


- class DecoderOnlyWrapper(torch.nn.Module):
- def __init__(self, model, max_seq_len):
- super().__init__()
- self.config = model.config
- self.model = model.model
- self.lm_head = model.lm_head
+ @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
+ def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
+ """
+ WORKAROUND:
+ Partition is declared as an argument to the function, even though it is
+ not actually used in the CPU implementation, this allows the rbln compiler
+ to perform flash attention operations with partition as an argument.
+ """
+ assert kcache.dim() == k.dim()
+ assert vcache.dim() == v.dim()
+ assert k.size(-2) == v.size(-2)
+ assert partition.dim() == 1
+ b = 0
+ if seq.dim() == 1:
+ s = seq[0]
+ elif seq.dim() == 0:
+ s = seq
+ else:
+ assert False
+ e = s + k.size(-2)
+ updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
+ updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
+ attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
+ attn_weight = attn_weight + mask
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+ attn_output = torch.matmul(attn_weight, updated_v)
+ return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+ @torch.library.impl_abstract("rbln_custom_ops::flash_attn_decode")
+ def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+ return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+ # RBLN custom op(flash attention prefill)
+ torch.library.define(
+ "rbln_custom_ops::flash_attn_prefill",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+ )
+
+
+ @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
+ def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition):
+ """
+ WORKAROUND:
+ Partition is declared as an argument to the function, even though it is
+ not actually used in the CPU implementation, this allows the rbln compiler
+ to perform flash attention operations with partition as an argument.
+ """
+ assert kcache.dim() == k.dim()
+ assert vcache.dim() == v.dim()
+ assert k.size(-2) == v.size(-2)
+ assert partition.dim() == 1
+ if batch.dim() == 1:
+ b = batch[0]
+ elif batch.dim() == 0:
+ b = batch
+ else:
+ assert False
+ if seq.dim() == 1:
+ s = seq[0]
+ elif seq.dim() == 0:
+ s = seq
+ else:
+ assert False
+ e = s + k.size(-2)
+ updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
+ updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
+ attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
+ attn_weight = attn_weight + mask
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+ attn_output = torch.matmul(attn_weight, updated_v)
+ return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
+

- self.head_dim = (
- self.config.head_dim
- if hasattr(self.config, "head_dim")
- else self.config.hidden_size // self.config.num_attention_heads
+ @torch.library.impl_abstract("rbln_custom_ops::flash_attn_prefill")
+ def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
+ return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
+
+
+ # RBLN custom op(cache update)
+ torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
+
+
+ @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
+ def rbln_cache_update_cpu(cache, value, batch, seq):
+ updated_cache = cache[batch].slice_scatter(value, dim=-2, start=batch[0], end=batch[0] + seq[0])
+ return updated_cache
+
+
+ @torch.library.impl_abstract("rbln_custom_ops::rbln_cache_update")
+ def rbln_cache_update_abstract(cache, value, batch, seq):
+ return torch.empty_like(cache)
+
+
+ class DecoderOnlyAttention:
+ def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
+ # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+ key_state = key_state.unsqueeze(2)
+ value_state = value_state.unsqueeze(2)
+ attn_mask = attn_mask.unsqueeze(2)
+
+ query_state = query_state.view(
+ 1,
+ self.num_key_value_heads,
+ self.num_heads // self.num_key_value_heads,
+ -1,
+ self.head_dim,
  )
- self.max_position_embeddings = (
- self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
+
+ key_state, value_state = past_key_value.update(
+ key_state, value_state, self.layer_idx, batch_idx, read_first_step=is_prefill
  )
- self.max_seq_len = max_seq_len
- self.rope_scaling = getattr(self.config, "rope_scaling", None)
- self.rotary_emb = self._init_rope()

- def _init_rope(self):
- if self.rope_scaling is None:
- rotary_emb = RotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- base=self.config.rope_theta,
- )
- else:
- scaling_type = self.rope_scaling["type"]
- scaling_factor = self.rope_scaling["factor"]
- if scaling_type == "linear":
- rotary_emb = LinearScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.config.rope_theta,
- max_seq_len=self.max_seq_len,
- )
- elif scaling_type == "dynamic":
- rotary_emb = DynamicNTKScalingRotaryEmbedding(
- self.head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.config.rope_theta,
- max_seq_len=self.max_seq_len,
- )
- else:
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+ attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
+ attn_weight += attn_mask
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_state.dtype)
+ attn_output = torch.matmul(attn_weight, value_state)

- return rotary_emb
+ attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)

- def get_forward_dict(self):
- forward_dict = {
- "wrapper": DecoderOnlyModel.forward,
- "model": DecoderOnlyDecoderLayer.forward,
- "decoder_layer": DecoderOnlyAttention.forward,
- }
- return forward_dict
+ return attn_output, key_state, value_state

  def forward(
  self,
- input_ids_or_inputs_embeds,
- attention_mask,
- cache_position,
- batch_position,
- query_idx,
- *past_key_values,
- ):
- if input_ids_or_inputs_embeds.shape[1] == 1:
- rbln_batch_position = None
- else:
- rbln_batch_position = batch_position
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[RebelDynamicCache] = None,
+ batch_index: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cos: Optional[torch.Tensor] = None,
+ sin: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)

- if input_ids_or_inputs_embeds.ndim == 2:
- # input_ids
- input_ids = input_ids_or_inputs_embeds
- inputs_embeds = None
- elif input_ids_or_inputs_embeds.ndim == 3:
- # inputs_embeds
- input_ids = None
- inputs_embeds = input_ids_or_inputs_embeds
- else:
- raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

- # Formatting list of past_kv to DynamicCache class.
- past_key_values = RebelDynamicCache.from_input_format(
- cache_position,
- self.config.num_hidden_layers,
- *past_key_values,
- )
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

- forward_dict = self.get_forward_dict()
- outputs = forward_dict["wrapper"](
- self.model,
- input_ids=input_ids,
- inputs_embeds=inputs_embeds,
- attention_mask=attention_mask,
- position_ids=cache_position,
- past_key_values=past_key_values,
- batch_ids=rbln_batch_position,
- rotary_pos_emb=self.rotary_emb,
- forward_dict=forward_dict,
- )
+ # Decoder (bsz > 1)
+ if bsz > 1:
+ iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
+ for b in range(bsz):
+ attn_output, key_state, value_state = DecoderOnlyAttention._attn(
+ self,
+ query_states[b].unsqueeze(0),
+ key_states[b].unsqueeze(0),
+ value_states[b].unsqueeze(0),
+ attention_mask[b].unsqueeze(0),
+ past_key_value,
+ batch_idx=b,
+ is_prefill=False,
+ )

- hidden_states = outputs[0]
- if batch_position >= 0:
- hidden_states = hidden_states[:, query_idx].unsqueeze(1)
+ iterate_results["key_states"].append(key_state)
+ iterate_results["value_states"].append(value_state)
+ iterate_results["attn_output"].append(attn_output)

- logits = self.lm_head(hidden_states)
+ key_states = torch.cat(iterate_results["key_states"], dim=0)
+ value_states = torch.cat(iterate_results["value_states"], dim=0)
+ attn_output = torch.cat(iterate_results["attn_output"], dim=0)
+ # Prefill & Decoder (bsz == 1)
+ else:
+ attn_output, key_states, value_states = DecoderOnlyAttention._attn(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ past_key_value,
+ batch_idx=batch_index,
+ is_prefill=True,
+ )

- output = (logits,) + outputs[1:]
+ attn_output = self.o_proj(attn_output)

- return output, batch_position + query_idx
+ if not output_attentions:
+ attn_weight = None

+ return attn_output, attn_weight, key_states, value_states

- class DecoderOnlyAttention:
+
+ class DecoderOnlyFlashAttention:
  def forward(
  self,
  hidden_states: torch.Tensor,
  attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
  past_key_value: Optional[RebelDynamicCache] = None,
- batch_index: Optional[int] = None,
+ batch_index: Optional[torch.Tensor] = None,
  output_attentions: bool = False,
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
+ cache_pos_for_partitions: Optional[torch.Tensor] = None,
+ kvcache_partition_size: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  bsz, q_len, _ = hidden_states.size()
-
  query_states = self.q_proj(hidden_states)
  key_states = self.k_proj(hidden_states)
  value_states = self.v_proj(hidden_states)
@@ -171,8 +271,8 @@ class DecoderOnlyAttention:

  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

- # Decoder
- if (batch_index is None or batch_index == -1) and bsz > 1:
+ # Decoder (bsz > 1)
+ if bsz > 1:
  all_key_states = []
  all_value_states = []
  all_attn_output = []
@@ -196,25 +296,21 @@ class DecoderOnlyAttention:
  self.head_dim,
  )

- key_state, value_state = past_key_value.update(
+ # RBLN custom flash attention(decode), dummy batch index
+ sidx = cache_pos_for_partitions[b][0]
+ attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
+ query_state,
  key_state,
  value_state,
- self.layer_idx,
- b,
+ attn_mask,
+ past_key_value.key_cache[self.layer_idx].unsqueeze(2),
+ past_key_value.value_cache[self.layer_idx].unsqueeze(2),
+ sidx,
+ kvcache_partition_size,
  )

- # reshape for removing repeat_kv
- attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
-
- attn_weight = attn_weight + attn_mask
-
- # upcast attention to fp32
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_output = torch.matmul(attn_weight, value_state)
-
  # reshape for removing repeat_kv
  attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-
  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)

@@ -227,9 +323,6 @@ class DecoderOnlyAttention:
  attn_output = torch.cat(all_attn_output, dim=0)

  else:
- if batch_index is None or batch_index == -1:
- batch_index = 0
-
  # reshape for removing repeat_kv
  key_states = key_states.unsqueeze(2)
  value_states = value_states.unsqueeze(2)
@@ -242,21 +335,22 @@ class DecoderOnlyAttention:
  self.head_dim,
  )

- key_states, value_states = past_key_value.update(
+ assert batch_index.dim() == 0
+ assert not output_attentions
+ bidx = batch_index
+ sidx = cache_pos_for_partitions[0][0]
+ attn_output, key_states, value_states = torch.ops.rbln_custom_ops.flash_attn_prefill(
+ query_states,
  key_states,
  value_states,
- self.layer_idx,
- batch_index,
- read_first_step=True,
+ attention_mask,
+ past_key_value.key_cache[self.layer_idx].unsqueeze(2),
+ past_key_value.value_cache[self.layer_idx].unsqueeze(2),
+ bidx,
+ sidx,
+ kvcache_partition_size,
  )

- attn_weight = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
- attn_weight = attn_weight + attention_mask
-
- # upcast attention to fp32
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_output = torch.matmul(attn_weight, value_states)
-
  # reshape for removing repeat_kv
  attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
  attn_output = attn_output.transpose(1, 2).contiguous()
@@ -270,6 +364,128 @@ class DecoderOnlyAttention:
  return attn_output, attn_weight, key_states, value_states


+ DECODERONLY_ATTENTION_CLASSES = {
+ "eager": DecoderOnlyAttention,
+ "flash_attn_rbln": DecoderOnlyFlashAttention,
+ # "sdpa": DecoderOnlySdpaAttention,
+ }
+
+
+ class DecoderOnlyWrapper(torch.nn.Module):
+ def __init__(self, model, max_seq_len, kvcache_partition_len=None):
+ super().__init__()
+ self.config = model.config
+ self.model = model.model
+ self.lm_head = model.lm_head
+ self.max_seq_len = max_seq_len
+ self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+ if kvcache_partition_len is not None:
+ # WORKAROUND : for passing partition length as a value to the rbln compiler.
+ # What is actually used is the shape of this tensor.
+ self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+ self.attn_implementation = "flash_attn_rbln"
+ logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
+ else:
+ self.kvcache_partition_size = None
+ self.attn_implementation = "eager"
+
+ def get_forward_dict(self):
+ forward_dict = {
+ "wrapper": DecoderOnlyModel.forward,
+ "model": DecoderOnlyDecoderLayer.forward,
+ "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
+ }
+ return forward_dict
+
+ def forward(
+ self,
+ input_ids_or_inputs_embeds,
+ attention_mask,
+ cache_position,
+ batch_position,
+ query_idx,
+ *past_key_values,
+ ):
+ if input_ids_or_inputs_embeds.ndim == 2:
+ # input_ids
+ input_ids = input_ids_or_inputs_embeds
+ inputs_embeds = None
+ elif input_ids_or_inputs_embeds.ndim == 3:
+ # inputs_embeds
+ input_ids = None
+ inputs_embeds = input_ids_or_inputs_embeds
+ else:
+ raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+ # Formatting list of past_kv to DynamicCache class.
+ past_key_values = RebelDynamicCache.from_input_format(
+ cache_position,
+ self.config.num_hidden_layers,
+ *past_key_values,
+ )
+
+ batch_size = input_ids_or_inputs_embeds.size()[0]
+ seq_len = input_ids_or_inputs_embeds.size()[1]
+
+ if self.attn_implementation == "eager":
+ cache_pos_for_partitions = None
+ elif self.attn_implementation == "flash_attn_rbln":
+ p_len = self.kvcache_partition_size.size()[0]
+ num_partition = self.max_seq_len // p_len
+ if self.max_seq_len % p_len > 0:
+ raise ValueError(
+ f"The partition length({p_len}) must be exactly divisible by the max_seq_len({self.max_seq_len})."
+ )
+ cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
+
+ if batch_size > 1: # decode
+ for b_idx in range(batch_size):
+ decoding_step = cache_position[b_idx]
+ cache_pos = decoding_step
+ for p_idx in range(num_partition):
+ input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
+ input_1 = torch.tensor(p_len, dtype=torch.int32)
+ min = torch.minimum(input_0, input_1)
+ cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
+ cache_pos_for_partitions[b_idx][p_idx] = cache_pos_for_partition
+ else: # prefill
+ cache_pos = cache_position[0][0]
+ for p_idx in range(num_partition):
+ input_0 = torch.tensor(cache_pos - p_len * p_idx, dtype=torch.int32)
+ input_1 = torch.tensor(p_len, dtype=torch.int32)
+ min = torch.minimum(input_0, input_1)
+ cache_pos_for_partition = torch.maximum(min, torch.tensor(0, dtype=torch.int32))
+ cache_pos_for_partitions[0][p_idx] = cache_pos_for_partition
+ else:
+ raise NotImplementedError(f"Unknown attn_implementation: {self.attn_implementation}")
+
+ forward_dict = self.get_forward_dict()
+ outputs = forward_dict["wrapper"](
+ self.model,
+ input_ids=input_ids,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=cache_position,
+ past_key_values=past_key_values,
+ batch_ids=batch_position,
+ rotary_pos_emb=self.rotary_emb,
+ cache_pos_for_partitions=cache_pos_for_partitions,
+ kvcache_partition_size=self.kvcache_partition_size,
+ forward_dict=forward_dict,
+ )
+
+ hidden_states = outputs[0]
+ if seq_len != 1:
+ hidden_states = hidden_states[:, query_idx.to(torch.int).unsqueeze(0)]
+
+ logits = self.lm_head(hidden_states)
+
+ output = (logits,) + outputs[1:]
+
+ return output, batch_position + query_idx
+
+
  class DecoderOnlyDecoderLayer:
  def forward(
  self,
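The partition bookkeeping added above in DecoderOnlyWrapper.forward amounts to clamping cache_position - p_idx * p_len into [0, p_len] for every KV-cache partition, i.e. how many slots of each partition are already occupied. A worked example of that arithmetic, written as an equivalent vectorized sketch rather than the loop used in the package, with illustrative sizes (max_seq_len=8192, p_len=2048, current decoding step 5000):

    import torch

    max_seq_len, p_len = 8192, 2048          # 4 partitions
    cache_pos = torch.tensor(5000)           # current decoding step
    num_partition = max_seq_len // p_len

    cache_pos_for_partitions = torch.clamp(
        cache_pos - p_len * torch.arange(num_partition), min=0, max=p_len
    )
    print(cache_pos_for_partitions)          # -> [2048, 2048, 904, 0]

Partitions 0 and 1 are full, partition 2 holds 904 cached positions, and partition 3 is still empty; these per-partition offsets are what the flash-attention custom ops receive as cache_pos_for_partitions.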
@@ -280,9 +496,11 @@ class DecoderOnlyDecoderLayer:
  past_key_value: Optional[RebelDynamicCache] = None,
  output_attentions: Optional[bool] = None,
  use_cache: Optional[bool] = None,
- batch_ids: Optional[torch.LongTensor] = None,
+ batch_ids: Optional[torch.Tensor] = None,
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
+ cache_pos_for_partitions: Optional[torch.Tensor] = None,
+ kvcache_partition_size: Optional[torch.Tensor] = None,
  forward_dict: Optional[Dict[str, classmethod]] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -301,6 +519,8 @@ class DecoderOnlyDecoderLayer:
  use_cache=use_cache,
  cos=cos,
  sin=sin,
+ cache_pos_for_partitions=cache_pos_for_partitions,
+ kvcache_partition_size=kvcache_partition_size,
  **kwargs,
  )
  past_key_value.assign(k, v, layer_idx)
@@ -331,11 +551,13 @@ class DecoderOnlyModel:
  attention_mask: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
  past_key_values: Optional[RebelDynamicCache] = None,
- batch_ids: Optional[torch.LongTensor] = None,
+ batch_ids: Optional[torch.Tensor] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  use_cache: Optional[bool] = True,
  output_attentions: Optional[bool] = False,
  output_hidden_states: Optional[bool] = False,
+ cache_pos_for_partitions: Optional[torch.Tensor] = None,
+ kvcache_partition_size: Optional[torch.Tensor] = None,
  forward_dict: Optional[Dict[str, classmethod]] = None,
  rotary_pos_emb=None,
  ) -> BaseModelOutputWithPast:
@@ -375,6 +597,8 @@ class DecoderOnlyModel:
  batch_ids=batch_ids,
  cos=cos,
  sin=sin,
+ cache_pos_for_partitions=cache_pos_for_partitions,
+ kvcache_partition_size=kvcache_partition_size,
  forward_dict=forward_dict,
  )

@@ -435,106 +659,40 @@ def apply_rotary_pos_emb(q, k, cos, sin):


  class RotaryEmbedding(nn.Module):
+ """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
  def __init__(
  self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
+ config: PretrainedConfig,
+ max_seq_len_cached: int,
  ):
  super().__init__()

- self.scaling_factor = scaling_factor
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ rope_type = "default"

- # Build here to make `torch.jit.trace` work.
- device = self.inv_freq.device
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+ position_ids = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+ position_ids_expanded = position_ids[:, None]
+
+ if rope_type == "dynamic":
+ freqs = position_ids_expanded.float() * inv_freq.float()
+ else:
+ inv_freq_expanded = inv_freq[None, :]
+ freqs = position_ids_expanded.float() @ inv_freq_expanded.float()

- positions_ids = torch.arange(self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
- freqs = torch.outer(positions_ids, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
  emb = torch.cat((freqs, freqs), dim=-1)

- self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
- self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
+ cos = emb.cos() * attention_scaling
+ sin = emb.sin() * attention_scaling
+
+ self.register_buffer("_cos_cached", cos, persistent=False)
+ self.register_buffer("_sin_cached", sin, persistent=False)

  def forward(self, x, seq_len):
  return (
  self._cos_cached[:seq_len].to(dtype=x.dtype),
  self._sin_cached[:seq_len].to(dtype=x.dtype),
  )
-
-
- class LinearScalingRotaryEmbedding(RotaryEmbedding):
- """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
- def __init__(
- self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- max_seq_len=2048,
- ):
- super().__init__(
- dim,
- max_position_embeddings=max_position_embeddings,
- base=base,
- scaling_factor=scaling_factor,
- )
- # difference to the original RoPE: a scaling factor is aplied to the position ids
- if max_seq_len > max_position_embeddings:
- positions_ids = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
- positions_ids = positions_ids / self.scaling_factor
- freqs = torch.outer(positions_ids, self.inv_freq)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- self._cos_cached = torch.cat([self._cos_cached, cos[max_position_embeddings:]], dim=0)
- self._sin_cached = torch.cat([self._sin_cached, sin[max_position_embeddings:]], dim=0)
-
-
- class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
- """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
- def __init__(
- self,
- dim,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- max_seq_len=2048,
- ):
- super().__init__(
- dim,
- max_position_embeddings=max_position_embeddings,
- base=base,
- scaling_factor=scaling_factor,
- )
- # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
- device = self.inv_freq.device
- dtype = self.inv_freq.dtype
- if max_seq_len > max_position_embeddings:
- position_ids = torch.arange(max_position_embeddings, max_seq_len, dtype=dtype).view(-1, 1)
- seq_len = position_ids + 1
- base = self.base * (
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
- ) ** (self.dim / (self.dim - 2))
-
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-
- freqs = position_ids * inv_freq
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
-
- self._cos_cached = torch.cat([self._cos_cached, cos], dim=0)
- self._sin_cached = torch.cat([self._sin_cached, sin], dim=0)
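The final hunk collapses the three RoPE classes (RotaryEmbedding, LinearScalingRotaryEmbedding, DynamicNTKScalingRotaryEmbedding) into a single config-driven RotaryEmbedding that resolves the rope type from config.rope_scaling and precomputes the full cos/sin tables up front via ROPE_INIT_FUNCTIONS. A minimal sketch of the math on the default (unscaled) path, using illustrative values rather than any real model config or the package's own helpers:

    import torch

    head_dim, rope_theta, max_seq_len = 128, 10000.0, 4096   # illustrative values

    # Default RoPE: inv_freq[i] = theta^(-2i/d), then an outer product with positions.
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = positions[:, None] @ inv_freq[None, :]            # (max_seq_len, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)                   # (max_seq_len, head_dim)

    cos_cached, sin_cached = emb.cos(), emb.sin()             # no extra attention scaling on this path
    print(cos_cached.shape)                                   # torch.Size([4096, 128])

At lookup time the retained forward(x, seq_len) above simply slices the first seq_len rows of these cached tables.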