optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -22,74 +22,209 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import math
25
- from typing import Dict, Optional, Tuple
25
+ from typing import List, Optional, Tuple
26
26
 
27
27
  import torch
28
28
  from torch import nn
29
- from transformers.modeling_outputs import (
30
- BaseModelOutputWithPast,
29
+ from transformers import PretrainedConfig, PreTrainedModel
30
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
31
+
32
+ from ....utils import logging
33
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
34
+
35
+
36
+ if is_torch_greater_or_equal_than_2_4:
37
+ register_fake = torch.library.register_fake
38
+ else:
39
+ register_fake = torch.library.impl_abstract
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+ """
44
+ ##############################################################################
45
+ # RBLN custom operation (python interface)
46
+ # torch.compile custom operation
47
+ # torch.library.define - kernel declaration
48
+ # torch.library.impl - kernel implementation
49
+ # torch.library.impl_abstract - symbolic trace
50
+ ##############################################################################
51
+ """
52
+
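The banner above names the three torch.library entry points this file relies on. A minimal sketch of that registration pattern, using a made-up op name and kernel that are not part of optimum-rbln:

import torch

# Declare the operator schema under a namespace.
torch.library.define("example_ops::scaled_add", "(Tensor x, Tensor y, Tensor alpha) -> Tensor")

# Concrete kernel used when the op is executed on CPU tensors.
@torch.library.impl("example_ops::scaled_add", "cpu")
def scaled_add_cpu(x, y, alpha):
    return x + alpha * y

# Shape/dtype-only implementation used during symbolic tracing
# (torch.library.register_fake is the torch >= 2.4 spelling, as selected above).
@torch.library.impl_abstract("example_ops::scaled_add")
def scaled_add_abstract(x, y, alpha):
    return torch.empty_like(x)

out = torch.ops.example_ops.scaled_add(torch.ones(3), torch.ones(3), torch.tensor(2.0))

The flash-attention ops declared below follow the same three steps, with fake implementations that only report output shapes (empty tensors shaped like the query and the caches).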
53
+ # RBLN custom op(flash attention decode)
54
+ torch.library.define(
55
+ "rbln_custom_ops::flash_attn_decode",
56
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
31
57
  )
32
58
 
33
- from ...cache_utils import RebelDynamicCache
34
59
 
60
+ @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
61
+ def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, partition):
62
+ """
63
+ WORKAROUND:
64
+ Partition is declared as an argument to the function, even though it is
65
+ not actually used in the CPU implementation; this allows the rbln compiler
66
+ to perform flash attention operations with partition as an argument.
67
+ """
68
+ assert kcache.dim() == k.dim()
69
+ assert vcache.dim() == v.dim()
70
+ assert k.size(-2) == v.size(-2)
71
+ assert partition.dim() == 1
72
+ b = 0
73
+ if seq.dim() == 1:
74
+ s = seq[0]
75
+ elif seq.dim() == 0:
76
+ s = seq
77
+ else:
78
+ assert False
79
+ e = s + k.size(-2)
80
+ updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
81
+ updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
82
+ attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
83
+ attn_weight = attn_weight + mask
84
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
85
+ attn_output = torch.matmul(attn_weight, updated_v)
86
+ return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
87
+
88
+
89
+ @register_fake("rbln_custom_ops::flash_attn_decode")
90
+ def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
91
+ return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
92
+
93
+
94
+ # RBLN custom op(flash attention prefill)
95
+ torch.library.define(
96
+ "rbln_custom_ops::flash_attn_prefill",
97
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
98
+ )
99
+
100
+
101
+ @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
102
+ def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, partition):
103
+ """
104
+ WORKAROUND:
105
+ Partition is declared as an argument to the function, even though it is
106
+ not actually used in the CPU implementation; this allows the rbln compiler
107
+ to perform flash attention operations with partition as an argument.
108
+ """
109
+ assert kcache.dim() == k.dim()
110
+ assert vcache.dim() == v.dim()
111
+ assert k.size(-2) == v.size(-2)
112
+ assert partition.dim() == 1
113
+ if batch.dim() == 1:
114
+ b = batch[0]
115
+ elif batch.dim() == 0:
116
+ b = batch
117
+ else:
118
+ assert False
119
+ if seq.dim() == 1:
120
+ s = seq[0]
121
+ elif seq.dim() == 0:
122
+ s = seq
123
+ else:
124
+ assert False
125
+ e = s + k.size(-2)
126
+ updated_k = kcache[b].unsqueeze(0).slice_scatter(k, dim=-2, start=s, end=e)
127
+ updated_v = vcache[b].unsqueeze(0).slice_scatter(v, dim=-2, start=s, end=e)
128
+ attn_weight = torch.matmul(q, updated_k.transpose(3, 4)) / math.sqrt(128)
129
+ attn_weight = attn_weight + mask
130
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
131
+ attn_output = torch.matmul(attn_weight, updated_v)
132
+ return attn_output, torch.empty_like(kcache), torch.empty_like(vcache)
133
+
134
+
135
+ @register_fake("rbln_custom_ops::flash_attn_prefill")
136
+ def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
137
+ return torch.empty_like(q), torch.empty_like(kcache), torch.empty_like(vcache)
138
+
139
+
140
+ # RBLN custom op(cache update)
141
+ torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
142
+
143
+
144
+ @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
145
+ def rbln_cache_update_cpu(cache, value, batch, seq):
146
+ updated_cache = cache[batch].slice_scatter(value, dim=-2, start=batch[0], end=batch[0] + seq[0])
147
+ return updated_cache
148
+
149
+
150
+ @register_fake("rbln_custom_ops::rbln_cache_update")
151
+ def rbln_cache_update_abstract(cache, value, batch, seq):
152
+ return torch.empty_like(cache)
153
+
154
+
155
+ class DecoderOnlyWrapper(nn.Module):
156
+ """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
157
+
158
+ This wrapper is designed to:
159
+ 1. Convert Huggingface decoder models for RBLN compilation with static shapes
160
+ 2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
161
+ 3. Manage different attention implementations (standard and flash attention)
162
+ 4. Support both prefill and decode phases
163
+
164
+ Notes:
165
+ - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
166
+ - Wrapper should not contain neural network graph operations (including memory view handling)
167
+
168
+ Args:
169
+ causal_lm (PreTrainedModel): The Huggingface causal language model to wrap
170
+ max_seq_len (int): Maximum sequence length for position embeddings and cache sizes
171
+ use_rotary_emb (bool): Whether to use rotary position embeddings
172
+ kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
173
+ If provided, uses flash attention; if None, uses standard attention
174
+ """
35
175
 
36
- class DecoderOnlyWrapper(torch.nn.Module):
37
- def __init__(self, model, max_seq_len):
176
+ def __init__(self, causal_lm: PreTrainedModel, max_seq_len, use_rotary_emb: bool, kvcache_partition_len=None):
38
177
  super().__init__()
39
- self.config = model.config
40
- self.model = model.model
41
- self.lm_head = model.lm_head
42
-
43
- self.head_dim = (
44
- self.config.head_dim
45
- if hasattr(self.config, "head_dim")
46
- else self.config.hidden_size // self.config.num_attention_heads
47
- )
48
- self.max_position_embeddings = (
49
- self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
50
- )
51
- self.max_seq_len = max_seq_len
52
- self.rope_scaling = getattr(self.config, "rope_scaling", None)
53
- self.rotary_emb = self._init_rope()
54
-
55
- def _init_rope(self):
56
- if self.rope_scaling is None:
57
- rotary_emb = RotaryEmbedding(
58
- self.head_dim,
59
- max_position_embeddings=self.max_position_embeddings,
60
- base=self.config.rope_theta,
61
- )
178
+ self.config = causal_lm.config
179
+
180
+ if use_rotary_emb:
181
+ self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
62
182
  else:
63
- scaling_type = self.rope_scaling["type"]
64
- scaling_factor = self.rope_scaling["factor"]
65
- if scaling_type == "linear":
66
- rotary_emb = LinearScalingRotaryEmbedding(
67
- self.head_dim,
68
- max_position_embeddings=self.max_position_embeddings,
69
- scaling_factor=scaling_factor,
70
- base=self.config.rope_theta,
71
- max_seq_len=self.max_seq_len,
72
- )
73
- elif scaling_type == "dynamic":
74
- rotary_emb = DynamicNTKScalingRotaryEmbedding(
75
- self.head_dim,
76
- max_position_embeddings=self.max_position_embeddings,
77
- scaling_factor=scaling_factor,
78
- base=self.config.rope_theta,
79
- max_seq_len=self.max_seq_len,
183
+ self.rotary_emb = None
184
+
185
+ if kvcache_partition_len is not None:
186
+ # WORKAROUND: the partition length is passed to the rbln compiler through a tensor;
187
+ # only the shape of that tensor is actually used.
188
+ self.attn_impl = "flash_attn"
189
+ logger.info(f"Using flash-attention. (partition length : {kvcache_partition_len})")
190
+ else:
191
+ self.attn_impl = "eager"
192
+ self.kvcache_partition_len = kvcache_partition_len
193
+
194
+ self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm)
195
+
196
+ self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
197
+ self._phase = "prefill"
198
+
199
+ def get_rotary_emb(self, max_seq_len):
200
+ return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
201
+
202
+ def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel):
203
+ new_layers = []
204
+ for layer in causal_lm.model.layers:
205
+ if self.attn_impl == "eager":
206
+ new_self_attn = DecoderOnlyAttention(layer.self_attn)
207
+ elif self.attn_impl == "flash_attn":
208
+ new_self_attn = DecoderOnlyFlashAttention(
209
+ layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
80
210
  )
81
211
  else:
82
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
212
+ raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
213
+
214
+ new_layer = DecoderOnlyLayer(layer, new_self_attn)
215
+ new_layers.append(new_layer)
216
+ new_model = DecoderOnlyModel(causal_lm.model, new_layers)
217
+ new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
218
+ return new_causal_lm
83
219
 
84
- return rotary_emb
220
+ @property
221
+ def phase(self):
222
+ return self._phase
85
223
 
86
- def get_forward_dict(self):
87
- forward_dict = {
88
- "wrapper": DecoderOnlyModel.forward,
89
- "model": DecoderOnlyDecoderLayer.forward,
90
- "decoder_layer": DecoderOnlyAttention.forward,
91
- }
92
- return forward_dict
224
+ @phase.setter
225
+ def phase(self, phase: str):
226
+ self._phase = phase
227
+ self.causal_lm.phase = phase
93
228
 
94
229
  def forward(
95
230
  self,
@@ -97,324 +232,514 @@ class DecoderOnlyWrapper(torch.nn.Module):
97
232
  attention_mask,
98
233
  cache_position,
99
234
  batch_position,
100
- query_idx,
235
+ query_position,
101
236
  *past_key_values,
102
237
  ):
103
- if input_ids_or_inputs_embeds.shape[1] == 1:
104
- rbln_batch_position = None
105
- else:
106
- rbln_batch_position = batch_position
107
-
108
238
  if input_ids_or_inputs_embeds.ndim == 2:
109
- # input_ids
239
+ # It is input_ids
110
240
  input_ids = input_ids_or_inputs_embeds
111
241
  inputs_embeds = None
112
242
  elif input_ids_or_inputs_embeds.ndim == 3:
113
- # inputs_embeds
243
+ # It is inputs_embeds
114
244
  input_ids = None
115
245
  inputs_embeds = input_ids_or_inputs_embeds
116
246
  else:
117
247
  raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
118
248
 
119
- # Formatting list of past_kv to DynamicCache class.
120
- past_key_values = RebelDynamicCache.from_input_format(
121
- cache_position,
122
- self.config.num_hidden_layers,
123
- *past_key_values,
124
- )
249
+ if len(past_key_values) != 2 * self.num_hidden_layers:
250
+ raise ValueError(
251
+ f"Mismatched past_key_values: got {len(past_key_values)} tensors, expected {2 * self.num_hidden_layers} (two per layer)."
252
+ )
125
253
 
126
- forward_dict = self.get_forward_dict()
127
- outputs = forward_dict["wrapper"](
128
- self.model,
254
+ seq_len = input_ids_or_inputs_embeds.shape[1]
255
+ if seq_len == 1:
256
+ self.phase = "decode"
257
+ else:
258
+ self.phase = "prefill"
259
+
260
+ # [key, value] * n_layer -> ( (key, value) ) * n_layer
261
+ # cache shape : batch, n_heads, 1, max_seq_len, head_dim
262
+ _past_key_values = []
263
+ for i in range(self.config.num_hidden_layers):
264
+ key_states = past_key_values[i * 2]
265
+ value_states = past_key_values[i * 2 + 1]
266
+ past_key_value = [key_states, value_states]
267
+ _past_key_values.append(past_key_value)
268
+ past_key_values = _past_key_values
269
+
270
+ logit, present_key_values = self.causal_lm(
129
271
  input_ids=input_ids,
130
272
  inputs_embeds=inputs_embeds,
131
273
  attention_mask=attention_mask,
132
- position_ids=cache_position,
274
+ cache_position=cache_position,
275
+ batch_position=batch_position,
276
+ query_position=query_position,
133
277
  past_key_values=past_key_values,
134
- batch_ids=rbln_batch_position,
135
- rotary_pos_emb=self.rotary_emb,
136
- forward_dict=forward_dict,
278
+ rotary_emb=self.rotary_emb,
137
279
  )
138
280
 
139
- hidden_states = outputs[0]
140
- if batch_position >= 0:
141
- hidden_states = hidden_states[:, query_idx].unsqueeze(1)
281
+ # ((key, value)) * n_layer -> [key, value] * n_layer
282
+ _present_key_values = ()
283
+ for i in range(self.num_hidden_layers):
284
+ key_states = present_key_values[i][0]
285
+ value_states = present_key_values[i][1]
286
+ _present_key_values = _present_key_values + (key_states, value_states)
287
+ present_key_values = _present_key_values
142
288
 
143
- logits = self.lm_head(hidden_states)
289
+ # batch_position + query_position is a dummy output node used to keep the number of outputs consistent
290
+ return logit, present_key_values, batch_position + query_position
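The flat cache handling above means a caller passes 2 * num_hidden_layers cache tensors positionally, ordered [k0, v0, k1, v1, ...], and forward() regroups them per layer. A small sketch of that convention; the layer count and cache shape below are placeholders, not values taken from the package:

import torch

num_layers = 2
cache_shape = (1, 4, 128, 64)  # assumed (batch, kv_heads, max_seq_len, head_dim); placeholder only

# Per-layer (key, value) caches, flattened to the positional layout expected by *past_key_values.
per_layer = [(torch.zeros(cache_shape), torch.zeros(cache_shape)) for _ in range(num_layers)]
flat_cache = [t for kv in per_layer for t in kv]  # [k0, v0, k1, v1, ...]

# The regrouping done inside forward():
regrouped = [[flat_cache[2 * i], flat_cache[2 * i + 1]] for i in range(num_layers)]
assert len(flat_cache) == 2 * num_layers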
144
291
 
145
- output = (logits,) + outputs[1:]
146
292
 
147
- return output, batch_position + query_idx
293
+ class DecoderOnlyForCausalLM(nn.Module):
294
+ """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
148
295
 
296
+ This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
297
+ 1. Managing model phases (prefill/decode) throughout the computation graph
298
+ 2. Handling output shape alignments for static compilation
299
+ 3. Coordinating between the original model and RBLN-optimized components
300
+
301
+ The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
302
+ focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.
303
+
304
+ Args:
305
+ causal_lm (PreTrainedModel): Original Huggingface causal language model
306
+ model (DecoderOnlyModel): RBLN-optimized model instance
307
+
308
+ Attributes:
309
+ config: Configuration from the original causal language model
310
+ _original_mod: Reference to the original model for components like lm_head
311
+ model: RBLN-optimized decoder model instance
312
+ _phase: Current processing phase ("prefill" or "decode")
313
+ """
314
+
315
+ def __init__(self, causal_lm: PreTrainedModel, model):
316
+ super().__init__()
317
+ self.config = causal_lm.config
318
+ self._original_mod = causal_lm
319
+ self.model = model
320
+ self._phase = "prefill"
321
+
322
+ @property
323
+ def phase(self):
324
+ return self._phase
325
+
326
+ @phase.setter
327
+ def phase(self, phase: str):
328
+ self._phase = phase
329
+ self.model.phase = phase
149
330
 
150
- class DecoderOnlyAttention:
151
331
  def forward(
152
332
  self,
153
- hidden_states: torch.Tensor,
154
- attention_mask: Optional[torch.Tensor] = None,
155
- past_key_value: Optional[RebelDynamicCache] = None,
156
- batch_index: Optional[int] = None,
157
- output_attentions: bool = False,
158
- cos: Optional[torch.Tensor] = None,
159
- sin: Optional[torch.Tensor] = None,
160
- **kwargs,
161
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
162
- bsz, q_len, _ = hidden_states.size()
333
+ input_ids: torch.Tensor = None,
334
+ inputs_embeds: torch.Tensor = None,
335
+ attention_mask: torch.Tensor = None,
336
+ cache_position: torch.Tensor = None,
337
+ batch_position: torch.Tensor = None,
338
+ query_position: torch.Tensor = None,
339
+ past_key_values: Tuple[Tuple[torch.Tensor]] = None,
340
+ rotary_emb: nn.Module = None,
341
+ ):
342
+ # outputs
343
+ hidden_states, present_key_values = self.model(
344
+ input_ids=input_ids,
345
+ inputs_embeds=inputs_embeds,
346
+ attention_mask=attention_mask,
347
+ cache_position=cache_position,
348
+ batch_position=batch_position,
349
+ past_key_values=past_key_values,
350
+ rotary_emb=rotary_emb,
351
+ )
163
352
 
164
- query_states = self.q_proj(hidden_states)
165
- key_states = self.k_proj(hidden_states)
166
- value_states = self.v_proj(hidden_states)
353
+ if self.phase == "prefill":
354
+ hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
167
355
 
168
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
169
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
170
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
171
-
172
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
173
-
174
- # Decoder
175
- if (batch_index is None or batch_index == -1) and bsz > 1:
176
- all_key_states = []
177
- all_value_states = []
178
- all_attn_output = []
179
-
180
- for b in range(bsz):
181
- query_state = query_states[b].unsqueeze(0)
182
- attn_mask = attention_mask[b].unsqueeze(0)
183
- key_state = key_states[b].unsqueeze(0)
184
- value_state = value_states[b].unsqueeze(0)
185
-
186
- # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
187
- key_state = key_state.unsqueeze(2)
188
- value_state = value_state.unsqueeze(2)
189
- attn_mask = attn_mask.unsqueeze(2)
190
-
191
- query_state = query_state.view(
192
- 1,
193
- self.num_key_value_heads,
194
- self.num_heads // self.num_key_value_heads,
195
- q_len,
196
- self.head_dim,
197
- )
356
+ logits = self._original_mod.lm_head(hidden_states)
357
+ output = (logits, present_key_values)
358
+ return output
198
359
 
199
- key_state, value_state = past_key_value.update(
200
- key_state,
201
- value_state,
202
- self.layer_idx,
203
- b,
204
- )
205
360
 
206
- # reshape for removing repeat_kv
207
- attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
361
+ class DecoderOnlyModel(nn.Module):
362
+ """A modified decoder-only model implementation optimized for RBLN compilation.
208
363
 
209
- attn_weight = attn_weight + attn_mask
364
+ Args:
365
+ model: Original Huggingface model to adapt
366
+ layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
210
367
 
211
- # upcast attention to fp32
212
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
213
- attn_output = torch.matmul(attn_weight, value_state)
368
+ Attributes:
369
+ _original_mod: Reference to original Huggingface model
370
+ layers: ModuleList of RBLN-optimized transformer layers
371
+ _phase: Current processing phase ("prefill" or "decode")
372
+ """
214
373
 
215
- # reshape for removing repeat_kv
216
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
374
+ mask_fmin = torch.finfo(torch.float16).min
217
375
 
218
- attn_output = attn_output.transpose(1, 2).contiguous()
219
- attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)
376
+ def __init__(self, model, layers: List["DecoderOnlyLayer"]):
377
+ super().__init__()
378
+ self._original_mod = model
379
+ self.layers = nn.ModuleList(layers)
380
+ self._phase = "prefill"
220
381
 
221
- all_key_states.append(key_state)
222
- all_value_states.append(value_state)
223
- all_attn_output.append(attn_output)
382
+ @property
383
+ def phase(self):
384
+ return self._phase
224
385
 
225
- key_states = torch.cat(all_key_states, dim=0)
226
- value_states = torch.cat(all_value_states, dim=0)
227
- attn_output = torch.cat(all_attn_output, dim=0)
386
+ @phase.setter
387
+ def phase(self, phase: str):
388
+ self._phase = phase
389
+ for layer in self.layers:
390
+ layer.phase = phase
228
391
 
229
- else:
230
- if batch_index is None or batch_index == -1:
231
- batch_index = 0
232
-
233
- # reshape for removing repeat_kv
234
- key_states = key_states.unsqueeze(2)
235
- value_states = value_states.unsqueeze(2)
236
- attention_mask = attention_mask.unsqueeze(2)
237
- query_states = query_states.view(
238
- 1,
239
- self.num_key_value_heads,
240
- self.num_heads // self.num_key_value_heads,
241
- q_len,
242
- self.head_dim,
392
+ @property
393
+ def hidden_multiplier(self):
394
+ return 1
395
+
396
+ def get_last_layernorm(self) -> nn.LayerNorm:
397
+ return self._original_mod.norm
398
+
399
+ def get_embedding(self) -> nn.Embedding:
400
+ return self._original_mod.embed_tokens
401
+
402
+ def get_pos_embedding(self) -> nn.Embedding:
403
+ raise NotImplementedError(
404
+ "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
405
+ )
406
+
407
+ def forward(
408
+ self,
409
+ input_ids: torch.Tensor = None,
410
+ inputs_embeds: torch.Tensor = None,
411
+ attention_mask: torch.Tensor = None,
412
+ cache_position: torch.Tensor = None,
413
+ batch_position: torch.Tensor = None,
414
+ past_key_values: Tuple[Tuple[torch.Tensor]] = None,
415
+ rotary_emb: nn.Module = None,
416
+ ):
417
+ # retrieve input_ids and inputs_embeds
418
+ if (input_ids is None) ^ (inputs_embeds is not None):
419
+ raise ValueError(
420
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
243
421
  )
244
422
 
245
- key_states, value_states = past_key_value.update(
246
- key_states,
247
- value_states,
248
- self.layer_idx,
249
- batch_index,
250
- read_first_step=True,
423
+ # embed positions
424
+ if inputs_embeds is None:
425
+ inputs_embeds = self.get_embedding()(input_ids)
426
+
427
+ hidden_states = inputs_embeds * self.hidden_multiplier
428
+ attention_mask = (1 - attention_mask) * self.mask_fmin
429
+
430
+ # get cos,sin vector if needed
431
+ if rotary_emb is not None:
432
+ cos, sin = rotary_emb(hidden_states, attention_mask.shape[-1]) # dtype carrier, max_seq_len
433
+ cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
434
+ else:
435
+ batch_size = inputs_embeds.shape[0]
436
+ if cache_position.shape[0] > 1:
437
+ position_embeds = []
438
+ for b_idx in range(batch_size):
439
+ position_embed = self.get_pos_embedding()(cache_position[b_idx])
440
+ position_embeds.append(position_embed)
441
+
442
+ position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
443
+ else:
444
+ position_embeds = self.get_pos_embedding()(cache_position)
445
+ hidden_states = hidden_states + position_embeds
446
+ cos, sin = None, None
447
+
448
+ # (batch, seq_len) -> (batch,)
449
+ current_steps = cache_position[:, 0]
450
+
451
+ present_key_values = past_key_values
452
+ for layer in self.layers:
453
+ hidden_states, present_key_values = layer(
454
+ hidden_states=hidden_states,
455
+ attention_mask=attention_mask,
456
+ current_steps=current_steps,
457
+ batch_position=batch_position,
458
+ past_key_values=present_key_values,
459
+ cos=cos,
460
+ sin=sin,
251
461
  )
252
462
 
253
- attn_weight = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
254
- attn_weight = attn_weight + attention_mask
463
+ hidden_states = self.get_last_layernorm()(hidden_states)
464
+ return hidden_states, present_key_values
255
465
 
256
- # upcast attention to fp32
257
- attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
258
- attn_output = torch.matmul(attn_weight, value_states)
259
466
 
260
- # reshape for removing repeat_kv
261
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
262
- attn_output = attn_output.transpose(1, 2).contiguous()
263
- attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
467
+ class DecoderOnlyLayer(nn.Module):
468
+ """A single transformer layer adapted for RBLN compilation with static shapes.
264
469
 
265
- attn_output = self.o_proj(attn_output)
470
+ This layer implements a modified transformer block that includes:
471
+ 1. Self-attention mechanism (either standard or flash attention)
472
+ 2. Feed-forward network (FFN)
473
+ 3. Layer normalization
474
+ 4. Residual connections
266
475
 
267
- if not output_attentions:
268
- attn_weight = None
476
+ The layer is specifically designed to:
477
+ - Support compilation to RBLN custom ops
478
+ - Maintain static tensor shapes throughout computations
479
+ - Handle both prefill and decode phases efficiently
480
+ - Manage attention state transitions properly
269
481
 
270
- return attn_output, attn_weight, key_states, value_states
482
+ Args:
483
+ layer: Original transformer layer module to wrap
484
+ self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
271
485
 
486
+ Attributes:
487
+ _original_mod: Reference to original layer for accessing components
488
+ self_attn: Modified attention mechanism mapped to RBLN ops at compile time
489
+ phase: Current operation phase ("prefill" or "decode")
490
+ """
491
+
492
+ def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
493
+ super().__init__()
494
+ self._original_mod = layer
495
+ self.self_attn = self_attn
496
+ self._phase = "prefill"
497
+
498
+ @property
499
+ def phase(self):
500
+ return self._phase
501
+
502
+ @phase.setter
503
+ def phase(self, phase: str):
504
+ self._phase = phase
505
+ self.self_attn.phase = phase
506
+
507
+ def get_pre_attention_layernorm(self) -> nn.LayerNorm:
508
+ return self._original_mod.input_layernorm
509
+
510
+ def get_post_attention_layernorm(self) -> nn.LayerNorm:
511
+ return self._original_mod.post_attention_layernorm
272
512
 
273
- class DecoderOnlyDecoderLayer:
274
513
  def forward(
275
514
  self,
276
515
  hidden_states: torch.Tensor,
277
- layer_idx: int,
278
- attention_mask: Optional[torch.Tensor] = None,
279
- position_ids: Optional[torch.LongTensor] = None,
280
- past_key_value: Optional[RebelDynamicCache] = None,
281
- output_attentions: Optional[bool] = None,
282
- use_cache: Optional[bool] = None,
283
- batch_ids: Optional[torch.LongTensor] = None,
516
+ attention_mask: torch.Tensor,
517
+ current_steps: torch.LongTensor,
518
+ batch_position: torch.Tensor,
519
+ past_key_values: Tuple[Tuple[torch.Tensor]],
284
520
  cos: Optional[torch.Tensor] = None,
285
521
  sin: Optional[torch.Tensor] = None,
286
- forward_dict: Optional[Dict[str, classmethod]] = None,
287
- **kwargs,
288
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
522
+ ):
289
523
  residual = hidden_states
290
524
 
291
- hidden_states = self.input_layernorm(hidden_states)
525
+ hidden_states = self.get_pre_attention_layernorm()(hidden_states)
292
526
 
293
- hidden_states, self_attn_weight, k, v = forward_dict["decoder_layer"](
294
- self.self_attn,
527
+ hidden_states, present_key_values = self.self_attn(
295
528
  hidden_states=hidden_states,
296
529
  attention_mask=attention_mask,
297
- position_ids=position_ids,
298
- past_key_value=past_key_value,
299
- output_attentions=output_attentions,
300
- batch_index=batch_ids,
301
- use_cache=use_cache,
530
+ current_steps=current_steps,
531
+ batch_position=batch_position,
532
+ past_key_values=past_key_values,
302
533
  cos=cos,
303
534
  sin=sin,
304
- **kwargs,
305
535
  )
306
- past_key_value.assign(k, v, layer_idx)
307
-
308
536
  hidden_states = residual + hidden_states
309
537
 
310
538
  # Fully Connected
311
539
  residual = hidden_states
312
- hidden_states = self.post_attention_layernorm(hidden_states)
313
- hidden_states = self.mlp(hidden_states)
540
+ hidden_states = self.get_post_attention_layernorm()(hidden_states)
541
+ hidden_states = self._original_mod.mlp(hidden_states)
314
542
  hidden_states = residual + hidden_states
315
543
 
316
- outputs = (hidden_states,)
544
+ return hidden_states, present_key_values
317
545
 
318
- if output_attentions:
319
- outputs += (self_attn_weight,)
320
546
 
321
- if use_cache:
322
- outputs += (past_key_value,)
547
+ class DecoderOnlyAttention(nn.Module):
548
+ """Attention implementation for decoder-only models optimized for RBLN compilation.
323
549
 
324
- return outputs
550
+ This class implements a modified version of the standard attention mechanism that:
551
+ 1. Supports static shape requirements for RBLN compilation
552
+ 2. Handles explicit batch and position management
325
553
 
554
+ Args:
555
+ self_attn: Original attention module from the base model
556
+ """
326
557
 
327
- class DecoderOnlyModel:
328
- def forward(
558
+ def __init__(self, self_attn):
559
+ super().__init__()
560
+ self._original_mod = self_attn
561
+ self.layer_idx = self_attn.layer_idx
562
+ self.num_heads = self._original_mod.num_heads
563
+ self.head_dim = self._original_mod.head_dim
564
+ self.phase = "prefill"
565
+ self.__post_init__()
566
+
567
+ def __post_init__(self):
568
+ self.q_proj = self._original_mod.q_proj
569
+ self.k_proj = self._original_mod.k_proj
570
+ self.v_proj = self._original_mod.v_proj
571
+ self.o_proj = self._original_mod.o_proj
572
+ self.num_key_value_heads = self._original_mod.num_key_value_heads
573
+
574
+ def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
575
+ """Projects input hidden states into query, key, and value representations.
576
+
577
+ Args:
578
+ hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
579
+
580
+ Returns:
581
+ Tuple of (query_states, key_states, value_states)
582
+ """
583
+ query_states = self.q_proj(hidden_states)
584
+ key_states = self.k_proj(hidden_states)
585
+ value_states = self.v_proj(hidden_states)
586
+ return query_states, key_states, value_states
587
+
588
+ def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
589
+ return apply_rotary_pos_emb(query_states, key_states, cos, sin)
590
+
591
+ def rbln_attention(
329
592
  self,
330
- input_ids: torch.LongTensor = None,
331
- attention_mask: Optional[torch.Tensor] = None,
332
- position_ids: Optional[torch.LongTensor] = None,
333
- past_key_values: Optional[RebelDynamicCache] = None,
334
- batch_ids: Optional[torch.LongTensor] = None,
335
- inputs_embeds: Optional[torch.FloatTensor] = None,
336
- use_cache: Optional[bool] = True,
337
- output_attentions: Optional[bool] = False,
338
- output_hidden_states: Optional[bool] = False,
339
- forward_dict: Optional[Dict[str, classmethod]] = None,
340
- rotary_pos_emb=None,
341
- ) -> BaseModelOutputWithPast:
342
- # retrieve input_ids and inputs_embeds
343
- if (input_ids is None) ^ (inputs_embeds is not None):
344
- raise ValueError(
345
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
346
- )
593
+ query_state,
594
+ key_state,
595
+ value_state,
596
+ attn_mask,
597
+ batch_idx,
598
+ past_key_state,
599
+ past_value_state,
600
+ current_step,
601
+ # The arguments below exist for Midm and GPT, which require scaling of the attention weights
602
+ # TODO(jongho): Merge and manage scales generally
603
+ layer_idx=None,
604
+ scale_attn_weights: bool = None,
605
+ scale_attn_by_inverse_layer_idx: bool = None,
606
+ scale_qk_by_inverse_layer_idx: bool = None,
607
+ ):
608
+ """Compute attention with static shapes and explicit cache management.
609
+
610
+ Args:
611
+ query_state: Query tensor [1, num_heads, 1, head_dim]
612
+ key_state: Key tensor [1, num_heads, seq_len, head_dim]
613
+ value_state: Value tensor [1, num_heads, seq_len, head_dim]
614
+ attn_mask: Attention mask tensor
615
+ batch_idx: Batch index for cache lookup
616
+ past_key_state: Previous key cache states
617
+ past_value_state: Previous value cache states
618
+ current_step: Current position in sequence
619
+
620
+ Returns:
621
+ Tuple of (attention_output, key_state, value_state)
622
+ """
623
+ # Implementation details.
624
+ # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
625
+ key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
626
+ value_state = value_state.unsqueeze(2)
627
+ attn_mask = attn_mask.unsqueeze(2)
628
+
629
+ query_state = query_state.view(
630
+ 1,
631
+ self.num_key_value_heads,
632
+ self.num_heads // self.num_key_value_heads,
633
+ -1, # seq len
634
+ self.head_dim,
635
+ ) #
636
+
637
+ kend = current_step + key_state.shape[-2]
638
+ vend = current_step + value_state.shape[-2]
639
+
640
+ key_state = (
641
+ past_key_state[batch_idx]
642
+ .unsqueeze(0)
643
+ .unsqueeze(2)
644
+ .slice_scatter(key_state, dim=-2, start=current_step, end=kend)
645
+ )
646
+ value_state = (
647
+ past_value_state[batch_idx]
648
+ .unsqueeze(0)
649
+ .unsqueeze(2)
650
+ .slice_scatter(value_state, dim=-2, start=current_step, end=vend)
651
+ )
347
652
 
348
- # embed positions
349
- if inputs_embeds is None:
350
- inputs_embeds = self.embed_tokens(input_ids)
351
-
352
- hidden_states = inputs_embeds
353
- attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
354
-
355
- # get cos,sin vector
356
- cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
357
- cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
358
-
359
- # decoder layers
360
- all_hidden_states = () if output_hidden_states else None
361
- all_self_attns = () if output_attentions else None
362
-
363
- for layer_idx, decoder_layer in enumerate(self.layers):
364
- if output_hidden_states:
365
- all_hidden_states += (hidden_states,)
366
- layer_outputs = forward_dict["model"](
367
- decoder_layer,
368
- hidden_states,
369
- layer_idx,
370
- attention_mask=attention_mask,
371
- position_ids=position_ids,
372
- past_key_value=past_key_values,
373
- output_attentions=output_attentions,
374
- use_cache=use_cache,
375
- batch_ids=batch_ids,
376
- cos=cos,
377
- sin=sin,
378
- forward_dict=forward_dict,
379
- )
653
+ attn_weight = torch.matmul(query_state, key_state.transpose(3, 4))
654
+ attn_weight = attn_weight / math.sqrt(self.head_dim)
655
+
656
+ if layer_idx is not None and (scale_attn_by_inverse_layer_idx or scale_qk_by_inverse_layer_idx):
657
+ attn_weight = attn_weight / float(layer_idx + 1)
658
+
659
+ attn_weight += attn_mask
660
+
661
+ if layer_idx is not None and scale_qk_by_inverse_layer_idx:
662
+ attn_weight = attn_weight * float(layer_idx + 1)
380
663
 
381
- hidden_states = layer_outputs[0]
664
+ attn_weight = nn.functional.softmax(attn_weight, dim=-1)
382
665
 
383
- updated_cache = layer_outputs[2 if output_attentions else 1]
666
+ attn_output = torch.matmul(attn_weight, value_state)
384
667
 
385
- if output_attentions:
386
- all_self_attns += (layer_outputs[1],)
668
+ attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
669
+ attn_output = attn_output.transpose(1, 2).contiguous()
670
+ attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
387
671
 
388
- hidden_states = self.norm(hidden_states)
672
+ return attn_output, key_state, value_state
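The "removing repeat_kv" reshapes above rely on grouped-query attention broadcasting: the query is viewed as [1, num_key_value_heads, num_heads // num_key_value_heads, q_len, head_dim] so the matmul broadcasts each key/value head across its query group instead of tiling the KV heads. A shape-only sketch with made-up head counts:

import torch

num_heads, num_kv_heads, q_len, kv_len, head_dim = 8, 2, 1, 4, 16

q = torch.randn(1, num_heads, q_len, head_dim).view(
    1, num_kv_heads, num_heads // num_kv_heads, q_len, head_dim
)
k = torch.randn(1, num_kv_heads, kv_len, head_dim).unsqueeze(2)  # unsqueezed instead of repeat_kv

attn_weight = torch.matmul(q, k.transpose(3, 4))  # broadcasts over the query-group dimension
print(attn_weight.shape)  # torch.Size([1, 2, 4, 1, 4])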
389
673
 
390
- # add hidden states from the last decoder layer
391
- if output_hidden_states:
392
- all_hidden_states += (hidden_states,)
674
+ def forward(
675
+ self,
676
+ hidden_states: torch.Tensor,
677
+ attention_mask: torch.Tensor,
678
+ current_steps: torch.LongTensor,
679
+ batch_position: torch.Tensor,
680
+ past_key_values: Tuple[Tuple[torch.Tensor]],
681
+ cos: Optional[torch.Tensor] = None, # (batch, 1, prefill_size, head_dim)
682
+ sin: Optional[torch.Tensor] = None,
683
+ ):
684
+ batch_size, query_length, _ = hidden_states.size()
393
685
 
394
- # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
395
- next_cache = updated_cache.to_legacy_cache()
686
+ query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
396
687
 
397
- return BaseModelOutputWithPast(
398
- last_hidden_state=hidden_states,
399
- past_key_values=next_cache,
400
- hidden_states=all_hidden_states,
401
- attentions=all_self_attns,
688
+ query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
689
+ key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
690
+ value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
691
+ 1, 2
402
692
  )
693
+ # b, num_head, query, head_dim
694
+
695
+ if cos is not None and sin is not None:
696
+ query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
697
+
698
+ if batch_size > 1 and self.phase == "prefill":
699
+ raise NotImplementedError(f"batch size should be 1 in the prefill phase, but got {batch_size}.")
700
+
701
+ _key_states = []
702
+ _value_states = []
703
+ _attn_outputs = []
704
+ for b in range(batch_size):
705
+ current_step = current_steps[b]
706
+ attn_output, key_state, value_state = self.rbln_attention(
707
+ query_states[b].unsqueeze(0),
708
+ key_states[b].unsqueeze(0),
709
+ value_states[b].unsqueeze(0),
710
+ attention_mask[b].unsqueeze(0)
711
+ if self.phase == "decode"
712
+ else attention_mask, # TODO(jongho): fix when msoftmax is supported
713
+ past_key_state=past_key_values[self.layer_idx][0],
714
+ past_value_state=past_key_values[self.layer_idx][1],
715
+ batch_idx=b if self.phase == "decode" else batch_position,
716
+ current_step=current_step,
717
+ )
718
+ _key_states.append(key_state)
719
+ _value_states.append(value_state)
720
+ _attn_outputs.append(attn_output)
721
+ key_states = torch.cat(_key_states, dim=0)
722
+ value_states = torch.cat(_value_states, dim=0)
723
+ attn_outputs = torch.cat(_attn_outputs, dim=0)
724
+
725
+ attn_outputs = self.o_proj(attn_outputs)
726
+ past_key_values[self.layer_idx] = key_states, value_states
727
+ return attn_outputs, past_key_values
403
728
 
404
729
 
405
- def slice_and_unsqueeze_cos_sin(cos, sin, position_ids, unsqueeze_dim=1):
406
- """Slice cos[position_ids], sin[position_ids] vector for the query."""
407
- if position_ids.shape[0] > 1:
730
+ def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
731
+ """Slice cos[cache_position], sin[cache_position] vector for the query."""
732
+ if cache_position.shape[0] > 1:
408
733
  cos_all = []
409
734
  sin_all = []
410
- for i in range(position_ids.shape[0]):
411
- cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
412
- sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
735
+ for i in range(cache_position.shape[0]):
736
+ cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
737
+ sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
413
738
  cos = torch.cat(cos_all, dim=0)
414
739
  sin = torch.cat(sin_all, dim=0)
415
740
  else:
416
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
417
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
741
+ cos = cos[cache_position].unsqueeze(unsqueeze_dim)
742
+ sin = sin[cache_position].unsqueeze(unsqueeze_dim)
418
743
 
419
744
  return cos, sin
420
745
 
@@ -434,34 +759,58 @@ def apply_rotary_pos_emb(q, k, cos, sin):
434
759
  return q_embed, k_embed
435
760
 
436
761
 
762
+ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
763
+ # Partial rotary embedding
764
+ query_rot, query_pass = (
765
+ query_states[..., :ndim],
766
+ query_states[..., ndim:],
767
+ )
768
+ key_rot, key_pass = (
769
+ key_states[..., :ndim],
770
+ key_states[..., ndim:],
771
+ )
772
+
773
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
774
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
775
+
776
+ # [batch_size, seq_length, num_heads, head_dim]
777
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
778
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
779
+ return query_states, key_states
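apply_rotary_pos_emb itself is only partially visible in this hunk; for context, the conventional rotate-half formulation used by transformers-style models, which the partial variant above wraps, looks roughly like this sketch (not copied verbatim from the package):

import torch

def rotate_half(x):
    # Swap the two halves of the last dimension, negating the second half.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed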
780
+
781
+
437
782
  class RotaryEmbedding(nn.Module):
783
+ """Rotary position embedding with cos/sin caches precomputed up to max_seq_len_cached; the scaling rule is selected from ROPE_INIT_FUNCTIONS based on config.rope_scaling."""
784
+
438
785
  def __init__(
439
786
  self,
440
- dim,
441
- max_position_embeddings=2048,
442
- base=10000,
443
- device=None,
444
- scaling_factor=1.0,
787
+ config: PretrainedConfig,
788
+ max_seq_len_cached: int,
445
789
  ):
446
790
  super().__init__()
447
791
 
448
- self.scaling_factor = scaling_factor
449
- self.dim = dim
450
- self.max_position_embeddings = max_position_embeddings
451
- self.base = base
452
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
453
- self.register_buffer("inv_freq", inv_freq, persistent=False)
792
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
793
+ rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
794
+ else:
795
+ rope_type = "default"
796
+
797
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
798
+ cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
799
+ cache_position_expanded = cache_position[:, None]
454
800
 
455
- # Build here to make `torch.jit.trace` work.
456
- device = self.inv_freq.device
801
+ if rope_type == "dynamic":
802
+ freqs = cache_position_expanded.float() * inv_freq.float()
803
+ else:
804
+ inv_freq_expanded = inv_freq[None, :]
805
+ freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
457
806
 
458
- positions_ids = torch.arange(self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
459
- freqs = torch.outer(positions_ids, self.inv_freq)
460
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
461
807
  emb = torch.cat((freqs, freqs), dim=-1)
462
808
 
463
- self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
464
- self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
809
+ cos = emb.cos() * attention_scaling
810
+ sin = emb.sin() * attention_scaling
811
+
812
+ self.register_buffer("_cos_cached", cos, persistent=False)
813
+ self.register_buffer("_sin_cached", sin, persistent=False)
465
814
 
466
815
  def forward(self, x, seq_len):
467
816
  return (
@@ -470,71 +819,140 @@ class RotaryEmbedding(nn.Module):
470
819
  )
471
820
 
472
821
 
473
- class LinearScalingRotaryEmbedding(RotaryEmbedding):
474
- """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
475
-
476
- def __init__(
822
+ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
823
+ def __init__(self, self_attn, kvcache_partition_len):
824
+ super().__init__(self_attn=self_attn)
825
+ self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
826
+
827
+ def get_cache_pos_for_partitions(self, current_steps, batch_size, max_seq_len):
828
+ partition_len = self.kvcache_partition_size.size()[0]
829
+ num_partition = max_seq_len // partition_len
830
+ cache_pos_for_partitions = torch.zeros((batch_size, num_partition), dtype=torch.int32)
831
+ if self.phase == "decode":
832
+ for b_idx in range(batch_size):
833
+ cache_pos = current_steps[b_idx]
834
+ for p_idx in range(num_partition):
835
+ cache_pos_for_partitions[b_idx][p_idx] = torch.clamp(
836
+ cache_pos - partition_len * p_idx, 0, partition_len
837
+ )
838
+ else: # prefill
839
+ cache_pos = current_steps[0]
840
+ for p_idx in range(num_partition):
841
+ cache_pos_for_partitions[0][p_idx] = torch.clamp(cache_pos - partition_len * p_idx, 0, partition_len)
842
+
843
+ return cache_pos_for_partitions
844
+
845
+ def rbln_flash_attention(
477
846
  self,
478
- dim,
479
- max_position_embeddings=2048,
480
- base=10000,
481
- device=None,
482
- scaling_factor=1.0,
483
- max_seq_len=2048,
847
+ query_state,
848
+ key_state,
849
+ value_state,
850
+ attn_mask,
851
+ batch_idx,
852
+ past_key_state,
853
+ past_value_state,
854
+ cache_pos_for_partitions,
484
855
  ):
485
- super().__init__(
486
- dim,
487
- max_position_embeddings=max_position_embeddings,
488
- base=base,
489
- scaling_factor=scaling_factor,
856
+ # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
857
+ key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
858
+ value_state = value_state.unsqueeze(2)
859
+ attn_mask = attn_mask.unsqueeze(2)
860
+
861
+ query_state = query_state.view(
862
+ 1,
863
+ self.num_key_value_heads,
864
+ self.num_heads // self.num_key_value_heads,
865
+ -1, # seq len
866
+ self.head_dim,
490
867
  )
491
- # difference to the original RoPE: a scaling factor is aplied to the position ids
492
- if max_seq_len > max_position_embeddings:
493
- positions_ids = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
494
- positions_ids = positions_ids / self.scaling_factor
495
- freqs = torch.outer(positions_ids, self.inv_freq)
496
- emb = torch.cat((freqs, freqs), dim=-1)
497
- cos = emb.cos()
498
- sin = emb.sin()
499
868
 
500
- self._cos_cached = torch.cat([self._cos_cached, cos[max_position_embeddings:]], dim=0)
501
- self._sin_cached = torch.cat([self._sin_cached, sin[max_position_embeddings:]], dim=0)
869
+ # RBLN custom flash attention(decode), dummy batch index
870
+ if self.phase == "decode":
871
+ sidx = cache_pos_for_partitions[batch_idx][0]
872
+ attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_decode(
873
+ query_state,
874
+ key_state,
875
+ value_state,
876
+ attn_mask,
877
+ past_key_state.unsqueeze(2),
878
+ past_value_state.unsqueeze(2),
879
+ sidx,
880
+ self.kvcache_partition_size,
881
+ )
882
+ else:
883
+ sidx = cache_pos_for_partitions[0][0]
884
+ attn_output, key_state, value_state = torch.ops.rbln_custom_ops.flash_attn_prefill(
885
+ query_state,
886
+ key_state,
887
+ value_state,
888
+ attn_mask,
889
+ past_key_state.unsqueeze(2),
890
+ past_value_state.unsqueeze(2),
891
+ batch_idx,
892
+ sidx,
893
+ self.kvcache_partition_size,
894
+ )
502
895
 
896
+ # reshape for removing repeat_kv
897
+ attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
898
+ attn_output = attn_output.transpose(1, 2).contiguous()
899
+ attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
503
900
 
504
- class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
505
- """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
901
+ return attn_output, key_state, value_state
506
902
 
507
- def __init__(
903
+ def forward(
508
904
  self,
509
- dim,
510
- max_position_embeddings=2048,
511
- base=10000,
512
- device=None,
513
- scaling_factor=1.0,
514
- max_seq_len=2048,
905
+ hidden_states: torch.Tensor,
906
+ attention_mask: torch.Tensor,
907
+ current_steps: torch.LongTensor,
908
+ batch_position: torch.Tensor,
909
+ past_key_values: Tuple[Tuple[torch.Tensor]],
910
+ cos: Optional[torch.Tensor] = None,
911
+ sin: Optional[torch.Tensor] = None,
515
912
  ):
516
- super().__init__(
517
- dim,
518
- max_position_embeddings=max_position_embeddings,
519
- base=base,
520
- scaling_factor=scaling_factor,
913
+ batch_size, query_length, _ = hidden_states.size()
914
+
915
+ query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
916
+
917
+ query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
918
+ key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
919
+ value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
920
+ 1, 2
521
921
  )
522
- # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
523
- device = self.inv_freq.device
524
- dtype = self.inv_freq.dtype
525
- if max_seq_len > max_position_embeddings:
526
- position_ids = torch.arange(max_position_embeddings, max_seq_len, dtype=dtype).view(-1, 1)
527
- seq_len = position_ids + 1
528
- base = self.base * (
529
- (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
530
- ) ** (self.dim / (self.dim - 2))
531
-
532
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
533
-
534
- freqs = position_ids * inv_freq
535
- emb = torch.cat((freqs, freqs), dim=-1)
536
- cos = emb.cos()
537
- sin = emb.sin()
538
-
539
- self._cos_cached = torch.cat([self._cos_cached, cos], dim=0)
540
- self._sin_cached = torch.cat([self._sin_cached, sin], dim=0)
922
+ # b, num_head, query, head_dim
923
+
924
+ max_seq_len = past_key_values[self.layer_idx][0].shape[-2]
925
+
926
+ if cos is not None and sin is not None:
927
+ query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
928
+
929
+ cache_pos_for_partitions = self.get_cache_pos_for_partitions(
930
+ current_steps, batch_size=batch_size, max_seq_len=max_seq_len
931
+ ) # batch_size, num_partitions
932
+
933
+ _key_states = []
934
+ _value_states = []
935
+ _attn_outputs = []
936
+ for b in range(batch_size):
937
+ attn_output, key_state, value_state = self.rbln_flash_attention(
938
+ query_states[b].unsqueeze(0),
939
+ key_states[b].unsqueeze(0),
940
+ value_states[b].unsqueeze(0),
941
+ attention_mask[b].unsqueeze(0)
942
+ if self.phase == "decode"
943
+ else attention_mask, # TODO(jongho): fix when msoftmax is supported
944
+ past_key_state=past_key_values[self.layer_idx][0],
945
+ past_value_state=past_key_values[self.layer_idx][1],
946
+ batch_idx=b if self.phase == "decode" else batch_position,
947
+ cache_pos_for_partitions=cache_pos_for_partitions,
948
+ )
949
+ _key_states.append(key_state)
950
+ _value_states.append(value_state)
951
+ _attn_outputs.append(attn_output)
952
+ key_states = torch.cat(_key_states, dim=0)
953
+ value_states = torch.cat(_value_states, dim=0)
954
+ attn_outputs = torch.cat(_attn_outputs, dim=0)
955
+
956
+ attn_outputs = self.o_proj(attn_outputs)
957
+ past_key_values[self.layer_idx] = key_states, value_states
958
+ return attn_outputs, past_key_values
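To make the partition bookkeeping in get_cache_pos_for_partitions concrete: with a partition length of 4 and a maximum sequence length of 12 (made-up numbers, not package defaults), a cache position of 6 yields per-partition positions [4, 2, 0], meaning the first partition is full, the second holds two cached steps, and the third is empty. A minimal check of that arithmetic:

import torch

partition_len, max_seq_len = 4, 12
num_partitions = max_seq_len // partition_len
cache_pos = torch.tensor(6)

positions = [
    int(torch.clamp(cache_pos - partition_len * p_idx, 0, partition_len))
    for p_idx in range(num_partitions)
]
print(positions)  # [4, 2, 0]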