optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (91)
  1. optimum/rbln/__init__.py +4 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +2 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +6 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +19 -250
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  81. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  82. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  83. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  84. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  85. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  86. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  87. optimum/rbln/utils/runtime_utils.py +3 -3
  88. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/METADATA +1 -1
  89. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/RECORD +91 -87
  90. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/WHEEL +0 -0
  91. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py (new file)
@@ -0,0 +1,450 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import deque
+ from typing import Any, Optional
+
+ import rebel
+ import torch
+ import torch.nn.functional as F
+
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNPageTableManager:
+     EMPTY_BLOCK = -1
+     NO_BLOCKS_ERROR = (
+         "No memory blocks are available for allocation. "
+         "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
+         "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
+         "Using vllm-rbln should fix this issue and enhance inference performance."
+     )
+
+     def __init__(self, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         self.rbln_config = rbln_config
+         self.block_tables = torch.zeros(
+             self.rbln_config.batch_size,
+             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+             dtype=torch.int16,
+         ).fill_(self.EMPTY_BLOCK)
+         self.free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
+     def update_block(self, batch_idx: int, block_idx: int):
+         """
+         If the block is empty (`EMPTY_BLOCK`), allocate a block from the `free_block_pool`.
+         """
+         if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
+             if self.free_block_pool:
+                 block = self.free_block_pool.popleft()
+                 self.block_tables[batch_idx][block_idx] = block
+             else:
+                 raise RuntimeError(self.NO_BLOCKS_ERROR)
+
+     def replace_empty_block(self, block_tables: torch.Tensor):
+         """
+         Replaces all occurrences of `EMPTY_BLOCK` in `block_tables` with a dummy block from `self.free_block_pool`.
+         """
+         if not torch.any(block_tables == self.EMPTY_BLOCK):
+             return block_tables.clone()
+         elif self.free_block_pool:
+             _free_block = self.free_block_pool[0]
+             return torch.where(block_tables == self.EMPTY_BLOCK, _free_block, block_tables)
+         else:
+             raise RuntimeError(self.NO_BLOCKS_ERROR)
+
+     def get_block_tables(
+         self, cache_position: torch.Tensor, batch_idx: int = None, batch_size: int = None, phase: str = "prefill"
+     ) -> torch.Tensor:
+         """
+         Manages and returns the KV cache block tables.
+         Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+         Args:
+             cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+             batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+         Returns:
+             Updated block tables.
+         """
+
+         def get_global_block_tables():
+             if not self.rbln_config.use_global_attention:
+                 return None
+
+             if phase == "prefill":
+                 # Track previously used blocks, return them to the free_block_pool, and
+                 # reset the current batch's block table to empty blocks
+                 prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.EMPTY_BLOCK].tolist()
+                 self.free_block_pool.extend(prev_blocks)
+                 self.block_tables[batch_idx].fill_(self.EMPTY_BLOCK)
+
+                 # Get the start (s) and end (e) positions from cache_position and
+                 # iterate over the cache positions to allocate necessary blocks
+                 s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                 for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                     block_idx = position // self.rbln_config.kvcache_block_size
+                     if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                         raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                     self.update_block(batch_idx, block_idx)
+
+                 return self.replace_empty_block(self.block_tables[batch_idx])
+             # Case for the 'decode' phase: iterate over the cache positions to allocate necessary blocks
+             else:
+                 for b_idx in range(batch_size):
+                     position = cache_position[b_idx][0].item()
+                     block_idx = position // self.rbln_config.kvcache_block_size
+                     self.update_block(b_idx, block_idx)
+
+                 return self.replace_empty_block(self.block_tables)
+
+         def get_local_block_tables():
+             if not self.rbln_config.use_local_attention:
+                 return None
+             else:
+                 return (
+                     torch.tensor([batch_idx], dtype=torch.int16)
+                     if phase == "prefill"
+                     else torch.arange(batch_size, dtype=torch.int16).view(batch_size, -1)
+                 )
+
+         return get_global_block_tables(), get_local_block_tables()
+
+     # Whether block_tables and local_block_tables are provided by the user
+     def is_external_block_tables(
+         self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+     ):
+         if self.rbln_config.cache_impl == "static" and block_tables is None:
+             return False
+         elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+             return False
+         elif self.rbln_config.cache_impl == "hybrid":
+             if (block_tables is not None) != (local_block_tables is not None):
+                 raise ValueError(
+                     "Both block_tables and local_block_tables must be provided or neither of them must be provided."
+                 )
+             elif block_tables is None and local_block_tables is None:
+                 return False
+
+         return True
+
+     def get_block_tables_if_needed(
+         self,
+         batch_size,
+         cache_position: torch.Tensor,
+         batch_idx: int = None,
+         phase: str = "prefill",
+         block_tables: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+     ):
+         is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+         if not is_external_block_tables:
+             block_tables, local_block_tables = self.get_block_tables(
+                 cache_position, batch_idx=batch_idx, batch_size=batch_size, phase=phase
+             )
+
+         return block_tables, local_block_tables, is_external_block_tables
+
+
+ class RBLNRuntimeModel(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name", "embed_tokens"]
+
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         phase: str,
+         batch_size: int,
+         dec_attn_mask: torch.Tensor,
+         page_table_manager: RBLNPageTableManager,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+         out_buffers: Optional[torch.Tensor] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.phase = phase
+         self.batch_size = batch_size
+         self.rbln_config = rbln_config
+
+         # shared resources between prefill and decode phase
+         self.dec_attn_mask = dec_attn_mask
+         self.page_table_manager = page_table_manager
+
+         if self.phase == "prefill":
+             self.out_buffers = out_buffers
+             self.causal_mask = 1 - torch.triu(
+                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
+             )
+
+     def inputs_embeddings_if_needed(
+         self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+     ):
+         if input_ids is None and inputs_embeds is None:
+             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
+
+         if self.rbln_config.use_inputs_embeds:
+             return self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
+         else:
+             return input_ids
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         cache_position: torch.Tensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         batch_idx: Optional[int] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+     ):
+         inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
+         block_tables, local_block_tables, is_external_block_tables = (
+             self.page_table_manager.get_block_tables_if_needed(
+                 self.batch_size,
+                 cache_position,
+                 batch_idx=batch_idx,
+                 phase=self.phase,
+                 block_tables=block_tables,
+                 local_block_tables=local_block_tables,
+             )
+         )
+
+         if self.phase == "decode":
+             return self.decode_forward(
+                 inputs,
+                 cache_position,
+                 block_tables,
+                 is_external_block_tables,
+                 attention_mask=attention_mask,
+                 position_embed=position_embed,
+                 position_ids=position_ids,
+                 local_block_tables=local_block_tables,
+             )
+         else:
+             return self.prefill_forward(
+                 inputs,
+                 cache_position,
+                 attention_mask,
+                 batch_idx,
+                 block_tables,
+                 is_external_block_tables=is_external_block_tables,
+                 position_embed=position_embed,
+                 token_type_ids=token_type_ids,
+                 local_block_tables=local_block_tables,
+             )
+
+     def decode_forward(
+         self,
+         inputs: torch.Tensor,
+         cache_position: torch.Tensor = None,
+         block_tables: torch.Tensor = None,
+         is_external_block_tables: bool = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         if self.batch_size != cache_position.shape[0]:
+             raise RuntimeError(
+                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
+             )
+
+         if self.rbln_config.use_attention_mask and attention_mask is None:
+             for b_idx in range(self.batch_size):
+                 decoding_step = cache_position[b_idx].item()
+                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                     raise ValueError(
+                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                     )
+
+                 if is_external_block_tables:
+                     self.dec_attn_mask[b_idx].fill_(0)
+                     self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                 else:
+                     self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+             attention_mask = self.dec_attn_mask
+
+         logits = super().forward(
+             inputs,
+             cache_position,
+             block_tables,
+             local_block_tables,
+             position_embed,
+             attention_mask if self.rbln_config.use_attention_mask else None,
+             position_ids if self.rbln_config.use_position_ids else None,
+         )
+
+         return RBLNDecoderOnlyOutput(logits=logits)
+
+     def _prepare_prefill_inputs(
+         self,
+         inputs: torch.Tensor,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+     ):
+         """
+         Prepare inputs for the prefill phase.
+         """
+         # Handle continuous batching in a compiled graph by extracting valid inputs.
+         # If an attention mask is provided, select only the valid (non-masked) inputs.
+         if attention_mask is not None:
+             inputs = inputs[:, attention_mask.bool()]
+             position_embed = None if position_embed is None else position_embed[:, :, :, attention_mask.bool(), :]
+             token_type_ids = None if token_type_ids is None else token_type_ids[:, attention_mask.bool()]
+
+         query_length = inputs.shape[1]
+         if query_length > self.rbln_config.max_seq_len:
+             raise ValueError(
+                 f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
+             )
+
+         # Initialize attention mask for chunked processing
+         chunked_attention_mask = (
+             torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+             if self.rbln_config.use_attention_mask
+             else None
+         )
+
+         cache_position = (
+             torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
+         )
+         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+         padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+         if padding_size > 0:
+             inputs = (
+                 F.pad(inputs, (0, 0, 0, padding_size))
+                 if self.rbln_config.use_inputs_embeds
+                 else F.pad(inputs, (0, padding_size))
+             )
+             position_embed = F.pad(position_embed, (0, 0, 0, padding_size)) if position_embed is not None else None
+             token_type_ids = F.pad(token_type_ids, (0, padding_size), value=-1) if token_type_ids is not None else None
+             cache_position = F.pad(cache_position, (0, padding_size))
+
+         # Overwrite position_ids and padded_cache_lengths
+         position_ids = cache_position.clone() if self.rbln_config.use_position_ids else None
+         padded_cache_lengths = 0
+
+         return (
+             inputs,
+             cache_position,
+             chunked_attention_mask,
+             position_ids,
+             position_embed,
+             padded_cache_lengths,
+             query_length,
+             token_type_ids,
+         )
+
+     def prefill_forward(
+         self,
+         inputs: torch.Tensor,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         batch_idx: Optional[int] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         is_external_block_tables: Optional[bool] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         """
+         Performs chunked prefill for efficient KV-cache updates and memory optimization.
+         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+         """
+         (
+             inputs,
+             cache_position,
+             chunked_attention_mask,
+             position_ids,
+             position_embed,
+             padded_cache_lengths,
+             query_length,
+             token_type_ids,
+         ) = self._prepare_prefill_inputs(
+             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+         )
+
+         # Process input in chunks of size `prefill_chunk_size`
+         output_logits = []
+         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+             s, e = step, step + self.rbln_config.prefill_chunk_size
+             # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
+             input_chunk = inputs[:, s:e]
+             cache_pos_chunk = cache_position[:, s:e]
+             position_ids_chunk = position_ids[:, s:e] if self.rbln_config.use_position_ids else None
+             position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None
+
+             # Update attention mask to ensure proper causal behavior
+             if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
+                 if step > 0:  # update previous chunk
+                     chunked_attention_mask[
+                         :, :, :, s - self.rbln_config.prefill_chunk_size : e - self.rbln_config.prefill_chunk_size
+                     ] = 1
+                 chunked_attention_mask[:, :, :, s:e] = self.causal_mask
+
+             # Calculate query position if needed
+             if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
+                 query_position = (
+                     torch.tensor((query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16)
+                     if e >= query_length
+                     else torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
+                 )
+             else:
+                 query_position = None
+
+             # Forward pass for the current chunk
+             output_logit = super().forward(
+                 input_chunk,
+                 cache_pos_chunk,
+                 block_tables,
+                 local_block_tables,
+                 position_embed_chunk,
+                 query_position,
+                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                 position_ids_chunk,
+                 out=self.out_buffers,
+             )
+             output_logits.append(output_logit)
+
+         # Aggregate output_logits
+         output_logits = torch.concat(output_logits, dim=-2)
+         if self.rbln_config.logits_to_keep > 0:
+             output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
+         else:
+             output_logits = output_logits[:, :query_length, :]
+             # index copy for masked output_logits
+             if attention_mask is not None:
+                 new_output_logits = torch.full(
+                     (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                     fill_value=1e-10,
+                     dtype=output_logits.dtype,
+                 )
+                 mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                 new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+
+                 output_logits = new_output_logits
+
+         # Update decoder attention mask with processed KV-cache length from prefill phase
+         if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
+             self.dec_attn_mask[batch_idx].fill_(0)
+             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+
+         return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
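
As a reading aid (not part of the diff): the chunked-prefill path above pads the prompt to a multiple of `prefill_chunk_size` and then runs one compiled forward pass per chunk. The following minimal, self-contained sketch reproduces only that padding and chunking arithmetic with illustrative values; the chunk size, vocabulary size, and prompt length are assumptions, and the compiled runtime call is replaced by a print.

import torch
import torch.nn.functional as F

# Illustrative values (assumptions, not package defaults).
prefill_chunk_size = 128
query_length = 300  # prompt length in tokens

# Pad so the prompt length becomes a multiple of prefill_chunk_size,
# mirroring _prepare_prefill_inputs: (128 - 300) % 128 == 84, so 300 -> 384.
padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size
input_ids = torch.randint(0, 32000, (1, query_length))
input_ids = F.pad(input_ids, (0, padding_size))
cache_position = F.pad(torch.arange(query_length, dtype=torch.int32).unsqueeze(0), (0, padding_size))

# One forward pass per chunk; prefill_forward above calls the compiled runtime here.
for step in range(0, query_length, prefill_chunk_size):
    input_chunk = input_ids[:, step : step + prefill_chunk_size]
    cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size]
    print(step, tuple(input_chunk.shape), tuple(cache_pos_chunk.shape))
# Prints three chunks of shape (1, 128) for steps 0, 128, and 256.

When `logits_to_keep` is 0, only the first `query_length` rows of the concatenated logits are kept afterwards, so the padded positions never reach the caller.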
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py (new file)
@@ -0,0 +1,88 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Any, Dict, Optional
+
+ import torch
+ from transformers.generation.utils import GenerationMixin
+
+
+ if TYPE_CHECKING:
+     from ...modeling_outputs import RBLNDecoderOnlyOutput
+
+
+ class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
+     _supports_cache_class = False  # Needed for GenerationMixin
+     _is_stateful = False  # Needed for GenerationMixin
+
+     def _reorder_cache(self, past_key_values, beam_idx):
+         raise NotImplementedError
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         generate_idx: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         padded_cache_lengths: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         model_inputs = {}
+         is_prefill_phase = generate_idx is None
+
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             padded_cache_lengths = torch.zeros_like(generate_idx)
+             cache_position = None
+             position_ids = None
+         else:
+             if inputs_embeds is not None:
+                 # If `inputs_embeds` are passed, only use them in the first generation step for every prompt.
+                 inputs_embeds = None
+
+             input_ids = input_ids[:, -1:]
+             position_ids = generate_idx
+             cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         if inputs_embeds is not None:
+             if self.rbln_config.use_inputs_embeds:
+                 model_inputs.update({"inputs_embeds": inputs_embeds})
+             else:
+                 raise ValueError(
+                     "Passing inputs_embeds is only supported when the RBLN model is compiled with 'rbln_use_inputs_embeds' set to True."
+                 )
+         else:
+             model_inputs.update({"input_ids": input_ids})
+
+         model_inputs.update(
+             {
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
+                 "position_ids": position_ids,
+                 "padded_cache_lengths": padded_cache_lengths,
+             }
+         )
+
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(
+         self, outputs: "RBLNDecoderOnlyOutput", model_kwargs: Dict[str, Any], **kwargs
+     ) -> Dict[str, Any]:
+         # update generate_idx
+         model_kwargs["generate_idx"] = outputs.generate_idx
+         model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
+         return model_kwargs
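
A minimal sketch (again not part of the diff) tracing the bookkeeping that `prepare_inputs_for_generation` performs across the prefill and decode phases; the prompt lengths and left-padded layout are assumed for illustration.

import torch

# Two prompts of true lengths 5 and 3, left-padded to a common length of 5 (assumed layout).
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

# Prefill phase (generate_idx is None): record each prompt's valid length.
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()  # tensor([[5], [3]])
padded_cache_lengths = torch.zeros_like(generate_idx)

# First decode step: feed only the last token; cache_position points at the next
# KV-cache slot for each sequence, offset by any padding added during prefill.
position_ids = generate_idx
cache_position = generate_idx + padded_cache_lengths           # tensor([[5], [3]])
generate_idx = generate_idx + 1                                # tensor([[6], [4]])
print(cache_position.tolist(), generate_idx.tolist())          # [[5], [3]] [[6], [4]]

`_update_model_kwargs_for_generation` then carries `generate_idx` and `padded_cache_lengths` from one step's `RBLNDecoderOnlyOutput` into the kwargs of the next call, so each decode step resumes from the correct cache position.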