optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (105)
  1. optimum/rbln/__init__.py +32 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +24 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +31 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +25 -251
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
@@ -13,10 +13,8 @@
  # limitations under the License.

  import inspect
- from collections import deque
- from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

  import rebel
  import torch
@@ -24,21 +22,22 @@ from rebel.compile_context import CompileContext
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights
- from transformers.utils import ModelOutput

  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
- from ....utils.runtime_utils import RBLNPytorchRuntime
  from ...modeling_attention_utils import (
  RBLNDecoderOnlyFlashAttentionMixin,
  set_default_values,
  validate_attention_method,
  validate_sliding_window,
  )
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
  from ...utils.rbln_quantization import prepare_model_for_quantization
  from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
  from .decoderonly_architecture import DecoderOnlyWrapper
+ from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+ from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin


  logger = get_logger()
@@ -47,419 +46,6 @@ if TYPE_CHECKING:
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


- class RBLNRuntimeModel(RBLNPytorchRuntime):
- mandatory_members = ["main_input_name", "embed_tokens"]
-
- def __init__(
- self,
- runtime: rebel.Runtime,
- phase: str,
- batch_size: int,
- dec_attn_mask: torch.Tensor,
- block_tables: torch.Tensor,
- free_block_pool: Deque,
- rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
- **kwargs: Any,
- ) -> None:
- super().__init__(runtime, **kwargs)
- self.phase = phase
- self.batch_size = batch_size
- self.rbln_config = rbln_config
-
- # shared tensor between prefill and decode phase
- self.dec_attn_mask = dec_attn_mask
- self.block_tables = block_tables
- self.free_block_pool = free_block_pool
-
- self.empty_block = -1
- if self.phase == "prefill":
- vocab_size = kwargs.pop("vocab_size")
- self.output_size = [1, 1, vocab_size]
- self.causal_mask = 1 - torch.triu(
- torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
- )
-
- def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
- """
- Manages and returns the KV cache block tables.
- Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
-
- Args:
- cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
- batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
-
- Returns:
- Updated block tables.
- """
-
- NO_BLOCKS_ERROR = (
- "No memory blocks are available for allocation. "
- "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
- "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
- "Using vllm-rbln should fix this issue and enhance inference performance."
- )
-
- def update_block(batch_idx: int, block_idx: int):
- """
- If the block is empty (empty_block), allocates a block from the free_block_pool.
- """
- if self.block_tables[batch_idx][block_idx] == self.empty_block:
- if self.free_block_pool:
- block = self.free_block_pool.popleft()
- self.block_tables[batch_idx][block_idx] = block
- else:
- raise RuntimeError(NO_BLOCKS_ERROR)
-
- def replace_empty_block(block_tables: torch.Tensor):
- """
- Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
- """
- if not torch.any(block_tables == self.empty_block):
- return block_tables.clone()
- elif self.free_block_pool:
- _free_block = self.free_block_pool[0]
- return torch.where(block_tables == self.empty_block, _free_block, block_tables)
- else:
- raise RuntimeError(NO_BLOCKS_ERROR)
-
- def get_global_block_tables(batch_idx: int):
- if self.rbln_config.cache_impl == "sliding_window":
- return None
-
- if self.phase == "prefill":
- # Track previously used blocks and return them to the free_block_pool and
- # reset the current batch's block table to empty blocks
- prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
- self.free_block_pool.extend(prev_blocks)
- self.block_tables[batch_idx].fill_(self.empty_block)
-
- # Get the start (s) and end (e) positions from cache_position and
- # iterate over the cache positions to allocate necessary blocks
- s, e = cache_position[0][0].item(), cache_position[0][-1].item()
- for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
- block_idx = position // self.rbln_config.kvcache_block_size
- if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
- raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
- update_block(batch_idx, block_idx)
-
- return replace_empty_block(self.block_tables[batch_idx])
- # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
- else:
- for b_idx in range(self.batch_size):
- position = cache_position[b_idx][0].item()
- block_idx = position // self.rbln_config.kvcache_block_size
- update_block(b_idx, block_idx)
-
- return replace_empty_block(self.block_tables)
-
- def get_local_block_tables(batch_idx: int):
- if self.rbln_config.cache_impl == "static":
- return None
- else:
- return (
- torch.tensor([batch_idx], dtype=torch.int16)
- if self.phase == "prefill"
- else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
- )
-
- return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
-
- def is_external_block_tables(
- self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
- ):
- if self.rbln_config.cache_impl == "static" and block_tables is None:
- return False
- elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
- return False
- elif self.rbln_config.cache_impl == "hybrid":
- if (block_tables is not None) != (local_block_tables is not None):
- raise ValueError(
- "Both block_tables and local_block_tables must be provided or neither of them must be provided."
- )
- elif block_tables is None and local_block_tables is None:
- return False
-
- return True
-
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- cache_position: torch.Tensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- batch_idx: Optional[int] = None,
- block_tables: Optional[torch.Tensor] = None,
- position_embed: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.Tensor] = None,
- local_block_tables: Optional[torch.Tensor] = None,
- ):
- if input_ids is None and inputs_embeds is None:
- raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
- if inputs_embeds is None:
- inputs = input_ids
- if self.embed_tokens is not None:
- inputs = self.embed_tokens(inputs)
- else:
- inputs = inputs_embeds
-
- is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
- if not is_external_block_tables:
- block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-
- if self.phase == "decode":
- return self.decode_forward(
- inputs,
- cache_position,
- block_tables,
- is_external_block_tables,
- attention_mask=attention_mask,
- position_embed=position_embed,
- position_ids=position_ids,
- local_block_tables=local_block_tables,
- )
- else:
- return self.prefill_forward(
- inputs,
- cache_position,
- attention_mask,
- batch_idx,
- block_tables,
- is_external_block_tables=is_external_block_tables,
- position_embed=position_embed,
- token_type_ids=token_type_ids,
- local_block_tables=local_block_tables,
- )
-
- def decode_forward(
- self,
- inputs: torch.Tensor,
- cache_position: torch.Tensor = None,
- block_tables: torch.Tensor = None,
- is_external_block_tables: bool = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_embed: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- local_block_tables: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size = inputs.shape[0]
- if batch_size != self.batch_size:
- raise RuntimeError(
- f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
- )
-
- if batch_size != cache_position.shape[0]:
- raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
- if self.rbln_config.use_attention_mask and attention_mask is None:
- for b_idx in range(batch_size):
- decoding_step = cache_position[b_idx].item()
- if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
- raise ValueError(
- f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
- )
-
- if is_external_block_tables:
- self.dec_attn_mask[b_idx].fill_(0)
- self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
- else:
- self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
- attention_mask = self.dec_attn_mask
-
- if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
- block_tables = block_tables[: self.batch_size]
-
- if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
- attention_mask = attention_mask[: self.batch_size]
-
- logits = super().forward(
- inputs,
- cache_position,
- block_tables,
- local_block_tables,
- position_embed,
- attention_mask if self.rbln_config.use_attention_mask else None,
- position_ids if self.rbln_config.use_position_ids else None,
- )
-
- return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
-
- def _prepare_prefill_inputs(
- self,
- inputs: torch.Tensor,
- cache_position: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_embed: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.Tensor] = None,
- ):
- """
- Prepare inputs for prefill phase.
- """
- # Handle continuous batching in a compiled graph by extracting valid inputs
- # If an attention mask is provided, select only the valid (non-masked) inputs
- inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
- if position_embed is not None:
- position_embed = (
- position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
- )
- if token_type_ids is not None:
- token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
-
- query_length = inputs.shape[1]
- if query_length > self.rbln_config.max_seq_len:
- raise ValueError(
- f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
- )
-
- # Initialize attention mask for chunked processing
- chunked_attention_mask = (
- torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
- if self.rbln_config.use_attention_mask
- else None
- )
-
- # Buffer for storing output logits
- out_buffers = [
- torch.empty(
- size=self.output_size,
- dtype=torch.float32,
- device="cpu",
- )
- ]
-
- # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
- padding_size = 0
- if query_length % self.rbln_config.prefill_chunk_size != 0:
- padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
- # inputs_embeds
- if inputs.dim() == 3:
- inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
- # inputs_ids
- else:
- inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
- cache_position = torch.cat(
- [
- cache_position,
- torch.arange(
- query_length,
- query_length + padding_size,
- dtype=torch.int32,
- ).unsqueeze(0),
- ],
- dim=-1,
- )
-
- if position_embed is not None:
- position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
- if token_type_ids is not None:
- token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
- # Overwrite position_ids and padded_cache_lengths
- position_ids = cache_position.clone()
- padded_cache_lengths = 0
-
- return (
- inputs,
- cache_position,
- chunked_attention_mask,
- out_buffers,
- position_ids,
- position_embed,
- padded_cache_lengths,
- query_length,
- token_type_ids,
- )
-
- def prefill_forward(
- self,
- inputs: torch.Tensor,
- cache_position: torch.Tensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- batch_idx: int = None,
- block_tables: torch.Tensor = None,
- is_external_block_tables: bool = False,
- position_embed: Optional[torch.Tensor] = None,
- token_type_ids: Optional[torch.Tensor] = None,
- local_block_tables: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- """
- Performs chunked prefill for efficient KV-cache updates and memory optimization.
- Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
- and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
- """
- (
- inputs,
- cache_position,
- chunked_attention_mask,
- out_buffers,
- position_ids,
- position_embed,
- padded_cache_lengths,
- query_length,
- token_type_ids,
- ) = self._prepare_prefill_inputs(
- inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
- )
-
- # Process input in chunks of size `prefill_chunk_size`
- for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
- # Extract the current chunk of inputs and cache positions
- input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
- cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
- position_ids_chunk = (
- position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
- if position_ids is not None
- else None
- )
- if position_embed is not None:
- position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
- if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
- # Update attention mask to ensure proper causal behavior
- if step >= self.rbln_config.prefill_chunk_size:
- chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
- chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
- # Define query position
- if step + self.rbln_config.prefill_chunk_size >= query_length:
- query_position = torch.tensor(
- (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
- )
- else:
- query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
-
- # Forward pass for the current chunk
- logits = super().forward(
- input_chunk,
- cache_pos_chunk,
- block_tables,
- local_block_tables,
- position_embed_chunk if position_embed is not None else None,
- query_position,
- chunked_attention_mask if self.rbln_config.use_attention_mask else None,
- position_ids_chunk if self.rbln_config.use_position_ids else None,
- out=out_buffers,
- )
-
- # Update decoder attention mask with processed KV-cache length from prefill phase
- if not is_external_block_tables and self.rbln_config.use_attention_mask:
- self.dec_attn_mask[batch_idx].fill_(0)
- self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
- return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
-
-
- @dataclass
- class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
- logits: torch.FloatTensor = None
- generate_idx: torch.Tensor = None
- padded_cache_lengths: int = None
-
-
  class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  """
  A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
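The RBLNRuntimeModel removed above (its replacement now lives in decoderonly_runtime_utils.py together with RBLNPageTableManager, per the new imports) manages the paged KV cache with a per-slot block table and a shared free pool: prefill releases a slot's old blocks before reallocating, decode only tops up the block covering the current position. A minimal sketch of that allocation scheme, using hypothetical names and none of the RBLN-specific plumbing:

    from collections import deque
    import torch

    class TinyPageTable:
        # Toy paged KV-cache allocator mirroring the removed get_block_tables() flow.
        EMPTY = -1

        def __init__(self, batch_size, max_seq_len, block_size, num_blocks):
            self.block_size = block_size
            self.block_tables = torch.full(
                (batch_size, max_seq_len // block_size), self.EMPTY, dtype=torch.int16
            )
            self.free_block_pool = deque(range(num_blocks))

        def allocate_prefill(self, batch_idx, seq_len):
            # Return the slot's previously used blocks to the pool, then allocate fresh ones.
            row = self.block_tables[batch_idx]
            self.free_block_pool.extend(row[row != self.EMPTY].tolist())
            row.fill_(self.EMPTY)
            for pos in range(0, seq_len, self.block_size):
                if not self.free_block_pool:
                    raise RuntimeError("No memory blocks are available for allocation.")
                row[pos // self.block_size] = self.free_block_pool.popleft()
            return row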
@@ -495,18 +81,116 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  else:
  self.embed_tokens = None

- # TODO: add prefill runtime class.
- self.prefill_decoder = RBLNPytorchRuntime(runtime=self.model[0])
-
- # attributes for prefill
- if self.rbln_config.use_global_attention:
- self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
- if self.rbln_config.use_local_attention:
- self.local_block_tables = torch.tensor([0], dtype=torch.int16)
- if self.rbln_config.use_attention_mask:
- self.causal_mask = 1 - torch.triu(
- torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
+ self.setup_runtime()
+
+ def setup_runtime(self):
+ # Initialize resources to be used across Runtime instances (prefill and decode phases)
+ page_table_manager = RBLNPageTableManager(self.rbln_config)
+ dec_attn_mask = torch.zeros(
+ self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+ )
+ out_buffers = [torch.empty(self.prefill_output_size, dtype=torch.float32, device="cpu")]
+
+ common_kwargs = {
+ "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+ "embed_tokens": self.embed_tokens,
+ "dec_attn_mask": dec_attn_mask,
+ "page_table_manager": page_table_manager,
+ "rbln_config": self.rbln_config,
+ }
+ self.prefill_decoder = RBLNRuntimeModel(
+ runtime=self.model[0],
+ phase="prefill",
+ batch_size=self.rbln_config.batch_size,
+ out_buffers=out_buffers,
+ **common_kwargs,
+ )
+ if self.can_generate():
+ self.decoders = {}
+ for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+ self.decoders[batch_size] = RBLNRuntimeModel(
+ runtime=self.model[i + 1],
+ phase="decode",
+ batch_size=batch_size,
+ **common_kwargs,
+ )
+
+ # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+ self.decoder = self.decoders[self.rbln_config.batch_size]
+
+ @property
+ def prefill_output_size(self):
+ return (
+ 1,
+ self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+ self.config.hidden_size,
+ )
+
+ @classmethod
+ def get_quantized_model(
+ cls,
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ subfolder: str = "",
+ local_files_only: bool = False,
+ trust_remote_code: bool = False,
+ rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+ **kwargs,
+ ):
+ kwargs = cls.update_kwargs(kwargs)
+
+ if config is None:
+ config = AutoConfig.from_pretrained(
+ model_id,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ cache_dir=cache_dir,
+ trust_remote_code=trust_remote_code,
+ **kwargs,
  )
+ if config.torch_dtype == torch.bfloat16:
+ # FIXME: bfloat16 is not supported by rebel-compiler
+ config.torch_dtype = torch.float32
+
+ with no_init_weights():
+ model = cls.auto_model_class.from_config(config)
+
+ model = prepare_model_for_quantization(
+ model,
+ model_id,
+ kwargs.get("num_hidden_layers"),
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ rbln_quantization=rbln_config.quantization,
+ )
+ return model
+
+ def __getattr__(self, __name: str) -> Any:
+ # Special method to delegate attribute access to the original Huggingface LM class.
+ # This method is called when an attribute is not found in the current instance's dictionary.
+ # It enables transparent access to the original model's attributes and methods while maintaining
+ # proper method binding.
+
+ # The method implements a delegation pattern that:
+
+ # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+ # 2. For other attributes: Returns them directly from the original class
+
+ def redirect(func):
+ return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+ val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
+ if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+ return redirect(val)
+ return val

  @classmethod
  def save_torch_artifacts(
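The __getattr__ added above is a generic delegation trick: when an attribute is missing on the RBLN wrapper, it is looked up on the original Hugging Face class and, for methods, re-bound to the wrapper instance. A self-contained illustration of the same pattern with made-up classes (not the library's own code):

    import inspect

    class HFStyleModel:
        def describe(self):
            return f"model with {self.num_layers} layers"

    class RBLNWrapper:
        def __init__(self, num_layers):
            self.num_layers = num_layers

        def __getattr__(self, name):
            # Called only when normal lookup fails; borrow the attribute from the
            # original class and re-bind `self` so wrapper state stays visible to it.
            val = getattr(HFStyleModel, name)
            if callable(val) and "self" in inspect.signature(val).parameters:
                return lambda *args, **kwargs: val(self, *args, **kwargs)
            return val

    print(RBLNWrapper(32).describe())  # -> "model with 32 layers"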
@@ -532,6 +216,14 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  )
  return embed_tokens

+ def get_decoder(self):
+ if not self.can_generate():
+ raise ValueError("Decode stage is not supported in this model.")
+ return self.decoder
+
+ def can_generate(self):
+ return self.rbln_config.can_generate
+
  def get_input_embeddings(self):
  return self.embed_tokens

@@ -543,20 +235,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):

  @classmethod
  def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
- wrapper_cfg = {
- "max_seq_len": rbln_config.max_seq_len,
- "attn_impl": rbln_config.attn_impl,
- "kvcache_partition_len": rbln_config.kvcache_partition_len,
- "kvcache_block_size": rbln_config.kvcache_block_size,
- "use_rotary_emb": cls._use_rotary_emb,
- "use_attention_mask": rbln_config.use_attention_mask,
- "use_position_ids": rbln_config.use_position_ids,
- "use_inputs_embeds": rbln_config.use_inputs_embeds,
- "cache_impl": rbln_config.cache_impl,
- "sliding_window": rbln_config.sliding_window,
- "sliding_window_layers": rbln_config.sliding_window_layers,
- }
- return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+ return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()

  @classmethod
  def _compile_model(
@@ -608,38 +287,58 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):

  @classmethod
  @torch.inference_mode()
- def get_compiled_model(
- cls,
- model: PreTrainedModel,
- rbln_config: RBLNDecoderOnlyModelConfig,
- ):
+ def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
  wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
- compile_config = rbln_config.compile_cfgs[0]
+ prefill_compile_config = rbln_config.compile_cfgs[0]

  # Here we use meta tensor, for the memory efficiency.
- meta_tensor_names = [name for name, _, _ in compile_config.input_info if "past_key_values" in name]
- example_inputs = compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
- context, _ = cls._get_compile_context(compile_config, example_inputs)
+ meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+ prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+ context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)

- compiled_model = cls._compile_model(
- wrapped_model, compile_config, example_inputs, context, rbln_config, rbln_config.quantization, "prefill"
+ compiled_models = {}
+ compiled_models["prefill"] = cls._compile_model(
+ wrapped_model,
+ prefill_compile_config,
+ prefill_example_inputs,
+ context,
+ rbln_config,
+ rbln_config.quantization,
+ phase="prefill",
  )
- compiled_models = {"prefill": compiled_model}

- return compiled_models
+ if rbln_config.can_generate:
+ wrapped_model.phase = "decode"
+ for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+ dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+ compiled_decoder = cls._compile_model(
+ wrapped_model,
+ dec_compile_config,
+ dec_example_inputs,
+ context,
+ rbln_config,
+ rbln_config.quantization,
+ phase="decode",
+ )
+ compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

- @classmethod
- def get_quantized_model(
- cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
- ) -> PreTrainedModel:
- raise NotImplementedError
+ # check if the memory is enough to have additional blocks
+ required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+ if rbln_config.kvcache_num_blocks < required_num_blocks:
+ cls.maybe_suggest_kvcache_num_blocks(
+ compiled_models=compiled_models,
+ model_config=model.config,
+ rbln_config=rbln_config,
+ )
+
+ return compiled_models

  @classmethod
  def get_pytorch_model(
  cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
  ) -> PreTrainedModel:
  if rbln_config and rbln_config.quantization:
- model = cls.get_quantized_model(*args, **kwargs)
+ model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
  else:
  model = super().get_pytorch_model(*args, **kwargs)

@@ -664,48 +363,40 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
  is_prefill = query_length > 1

- # 1. main input
+ input_info = []
  if rbln_config.use_inputs_embeds:
- main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+ input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], "float32"))
  else:
- main_input = ("input_ids", [batch_size, query_length], "int64")
-
- # 2. cache_position
- input_info = [
- main_input,
- (
- "cache_position",
- [batch_size, query_length],
- "int32",
- ),
- ]
+ input_info.append(("input_ids", [batch_size, query_length], "int64"))
+
+ input_info.append(("cache_position", [batch_size, query_length], "int32"))

- # 3. block_tables
  if rbln_config.use_global_attention:
  max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
- input_info.extend(
- [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
+ input_info.append(
+ ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
  )
  if rbln_config.use_local_attention:
- input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+ input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

- # 4. query_position for sliding window attention
  if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
- input_info.extend([("query_position", [], "int16")])
+ input_info.append(("query_position", [], "int16"))

- # 5. attention_mask & position_ids
  if rbln_config.use_attention_mask:
- input_info.extend(
- [
- ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
- if rbln_config.use_position_ids
- else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
- ]
- )
+ if rbln_config.use_position_ids:
+ input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], "float32"))
+ else:
+ input_info.append(
+ ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
+ )
+
  if rbln_config.use_position_ids:
  input_info.append(("position_ids", [batch_size, query_length], "int32"))

- # 6. past_key_values
+ kvcache_dtype = "float32"
+ if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+ kvcache_dtype = "float8_e4m3fn"
+
  global_kvcache_shape = [
  rbln_config.kvcache_num_blocks,
  num_key_value_heads,
@@ -720,7 +411,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  local_kvcache_shape
  if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
  else global_kvcache_shape,
- "float32",
+ kvcache_dtype,
  )
  for i in range(num_hidden_layers * 2)
  ]
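get_input_info above now builds the compiler's input signature incrementally; each entry is a (name, shape, dtype) triple. For concreteness, a hypothetical prefill signature with batch_size=1, query_length=128, max_seq_len=4096 and kvcache_block_size=1024 (so max_block_cnt=4) would start roughly like this, before the past_key_values entries are appended; the exact entries depend on the attention and cache flags shown above:

    # Illustrative only; not a signature taken from the library.
    input_info = [
        ("input_ids", [1, 128], "int64"),
        ("cache_position", [1, 128], "int32"),
        ("block_tables", [4], "int16"),   # 1-D for prefill, [batch_size, 4] for decode
        ("query_position", [], "int16"),
    ]
    # past_key_values tensors follow, two per layer, in "float32" or
    # "float8_e4m3fn" when quantization.kv_caches == "fp8".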
@@ -784,15 +475,62 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  max_seq_len=rbln_config.max_seq_len,
  )

- if rbln_config.kvcache_num_blocks is None:
- rbln_config.kvcache_num_blocks = (
- rbln_config.max_seq_len // rbln_config.kvcache_block_size
- ) * rbln_config.batch_size
+ num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size

- return rbln_config
+ # Update kvcache_num_blocks based on the attention implementation.
+ if rbln_config.attn_impl == "flash_attn":
+ estimated_max_num_blocks = cls.get_maximum_num_blocks(
+ config=model_config,
+ tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+ kvcache_block_size=rbln_config.kvcache_block_size,
+ nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
+ n_model_params=sum(p.numel() for p in model.parameters()),
+ num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
+ )

- @classmethod
- def _update_rbln_config(
+ if rbln_config.kvcache_num_blocks is None:
+ if estimated_max_num_blocks < num_full_blocks:
+ # lower bound of the number of blocks for flash attention.
+ min_blocks_for_flash = min(
+ rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+ )
+ if min_blocks_for_flash > estimated_max_num_blocks:
+ # NOTE: Just try to compile with lower bound of blocks for flash attention.
+ # Even if it's larger than the estimated maximum number of blocks.
+ rbln_config.kvcache_num_blocks = min_blocks_for_flash
+ else:
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+ rbln_config.kvcache_num_blocks = estimated_max_num_blocks
+
+ if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+ raise RuntimeError(
+ f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
+ "Ensure the number of blocks is at least equal to the batch size."
+ )
+ else:
+ rbln_config.kvcache_num_blocks = num_full_blocks
+ elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
+ logger.warning(
+ f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+ f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
+ "This can cause a failure during model compilation."
+ )
+ else:
+ if rbln_config.kvcache_num_blocks is None:
+ rbln_config.kvcache_num_blocks = num_full_blocks
+ elif rbln_config.kvcache_num_blocks > num_full_blocks:
+ logger.warning(
+ f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+ f" than the required number of blocks ({num_full_blocks})."
+ "This can cause a failure during model compilation."
+ )
+
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+
+ return rbln_config
+
+ @classmethod
+ def _update_rbln_config(
  cls,
  preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
  model: Optional[PreTrainedModel] = None,
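The block-count selection above reduces to simple arithmetic. With hypothetical numbers, max_seq_len=8192, kvcache_block_size=1024 and batch_size=2:

    max_seq_len, block_size, batch_size = 8192, 1024, 2
    num_full_blocks = (max_seq_len // block_size) * batch_size                   # 16
    min_blocks_for_flash = min(max_seq_len // block_size + 1, num_full_blocks)   # 9
    # For flash attention, kvcache_num_blocks defaults to num_full_blocks when the
    # memory estimate allows it; otherwise it falls back to the estimate, floored at
    # min_blocks_for_flash, and must stay >= batch_size. Other attention
    # implementations simply default to num_full_blocks.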
@@ -823,7 +561,20 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  )

  prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
- rbln_config.set_compile_cfgs([prefill_compile_config])
+ compile_cfgs = [prefill_compile_config]
+
+ if rbln_config.can_generate:
+ for batch_size in rbln_config.decoder_batch_sizes:
+ dec_input_info = cls.get_input_info(
+ batch_size=batch_size,
+ query_length=1,
+ rbln_config=rbln_config,
+ model_config=model_config,
+ )
+ compile_cfgs.append(
+ RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+ )
+ rbln_config.set_compile_cfgs(compile_cfgs)

  return rbln_config

@@ -833,128 +584,37 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  compiled_models: List[rebel.RBLNCompiledModel],
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
  ) -> List[rebel.Runtime]:
- expected_model_names = [
- "prefill",
- ]
+ expected_model_names = ["prefill"]
+ if rbln_config.can_generate:
+ expected_model_names.extend(
+ [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+ )
  if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
  cls._raise_missing_compiled_file_error(expected_model_names)

- return [
+ ret_val = [
  rebel.Runtime(
  compiled_models[0],
  tensor_type="pt",
  device=rbln_config.device_map["prefill"],
  activate_profiler=rbln_config.activate_profiler,
- ),
- ]
-
- def _preprocess_chunked_prefill(
- self,
- inputs: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_embed: Optional[torch.Tensor] = None,
- ):
- # valid sequence length of inputs_embeds
- query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
-
- # extract valid inputs
- inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-
- if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
- inputs = self.get_input_embeddings()(inputs)
-
- if position_embed is not None:
- position_embed = (
- position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
- )
-
- # padding for chunked prefill
- padding_size = (
- self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
- ) % self.rbln_config.prefill_chunk_size
- padded_len = query_length + padding_size
-
- inputs = (
- torch.nn.functional.pad(inputs, (0, padding_size))
- if not self.rbln_config.use_inputs_embeds
- else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
- )
- position_embed = (
- None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
- )
- cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
-
- chunked_attention_mask = (
- torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
- if self.rbln_config.use_attention_mask
- else None
- )
-
- return inputs, position_embed, cache_position, query_length, chunked_attention_mask
-
- def _chunked_prefill_forward(
- self,
- inputs: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_embed: Optional[torch.Tensor] = None,
- ):
- padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
- self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
- )
-
- # chunked prefill
- last_hidden_states = []
- for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
- # Extract the current chunk of inputs and cache positions
- input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
- cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-
- valid_length = (
- self.rbln_config.prefill_chunk_size
- if (step + self.rbln_config.prefill_chunk_size) <= query_length
- else query_length - step
- )
- if self.rbln_config.use_local_attention:
- query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
- else:
- query_position = None
-
- if self.rbln_config.use_attention_mask:
- if step > 0:
- chunked_attention_mask[:, :, :, :step] = 1
- chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
- # Forward pass for the current chunk
- last_hidden_states_chunk = self.prefill_decoder(
- input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
- inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
- cache_position=cache_pos_chunk,
- block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
- local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
- query_position=query_position,
- attention_mask=chunked_attention_mask,
- position_emb=padded_position_embed,
+ timeout=rbln_config.timeout,
  )
- last_hidden_states.append(last_hidden_states_chunk)
- last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
-
- return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
-
- def _postprocess_chunked_prefill(
- self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
- ):
- # index copy for attention mask
- if attention_mask is not None:
- new_last_hidden_states = torch.full(
- (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
- fill_value=1e-10,
- dtype=last_hidden_states.dtype,
+ ]
+ if rbln_config.can_generate:
+ ret_val.extend(
+ [
+ rebel.Runtime(
+ compiled_models[i + 1],
+ tensor_type="pt",
+ device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+ activate_profiler=rbln_config.activate_profiler,
+ timeout=rbln_config.timeout,
+ )
+ for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+ ]
  )
- mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
- new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
- else:
- new_last_hidden_states = last_hidden_states
- return new_last_hidden_states
+ return ret_val

  def forward(
  self,
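The runtime creation above expects one compiled artifact named "prefill" plus one "decoder_batch_{N}" artifact per entry in decoder_batch_sizes (the same names produced earlier in get_compiled_model and kept as decode runtimes in setup_runtime's self.decoders dict). A quick sketch of the naming convention, with hypothetical settings:

    # Hypothetical config: can_generate=True, decoder_batch_sizes=(1, 4)
    decoder_batch_sizes = (1, 4)
    expected_model_names = ["prefill"] + [f"decoder_batch_{b}" for b in decoder_batch_sizes]
    print(expected_model_names)  # ['prefill', 'decoder_batch_1', 'decoder_batch_4']
    # Each name must also appear in rbln_config.device_map when the runtimes are
    # instantiated; the decode runtimes end up keyed by batch size, with
    # self.decoder defaulting to the entry matching the model's main batch size.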
@@ -966,20 +626,32 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  ) -> Tuple[torch.FloatTensor]:
  inputs = inputs_embeds if inputs_embeds is not None else input_ids
  batch_size = inputs.shape[0]
+
+ if batch_size != self.rbln_config.batch_size:
+ raise ValueError(
+ f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
+ )
+
  all_last_hidden_states = []
- for b_idx in range(batch_size):
- last_hidden_states = self._chunked_prefill_forward(
- inputs[b_idx : b_idx + 1],
- attention_mask[b_idx] if attention_mask is not None else None,
- position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+ for b_idx in range(self.rbln_config.batch_size):
+ query_length = (
+ attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
  )
+ cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+ last_hidden_states = self.prefill_decoder(
+ inputs[b_idx : b_idx + 1],
+ attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+ position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+ cache_position=cache_position,
+ batch_idx=b_idx,
+ ).logits
  all_last_hidden_states.append(last_hidden_states)

  last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
  return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)


- class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
+ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
  """
  A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
  This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
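The new RBLNDecoderOnlyModel.forward above runs prefill one sample at a time through self.prefill_decoder; the chunking itself (splitting a long prompt into prefill_chunk_size pieces, as described in the removed prefill_forward docstring) now happens inside the runtime wrapper. A minimal sketch of that loop, assuming a run_chunk callable standing in for the compiled prefill graph (hypothetical):

    import torch

    def chunked_prefill(run_chunk, input_ids: torch.Tensor, chunk_size: int):
        # Pad the prompt to a multiple of chunk_size, then feed the compiled prefill
        # graph one fixed-size chunk at a time; the last chunk's output is kept.
        query_length = input_ids.shape[1]
        pad = -query_length % chunk_size
        input_ids = torch.nn.functional.pad(input_ids, (0, pad))
        cache_position = torch.arange(query_length + pad, dtype=torch.int32).unsqueeze(0)
        out = None
        for step in range(0, query_length, chunk_size):
            out = run_chunk(
                input_ids[:, step : step + chunk_size],
                cache_position[:, step : step + chunk_size],
            )
        return out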
@@ -1002,380 +674,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
1002
674
 
1003
675
  auto_model_class = AutoModelForCausalLM
1004
676
 
1005
- def __post_init__(self, **kwargs):
1006
- main_input_name = self.main_input_name
1007
-
1008
- if self.rbln_config.use_inputs_embeds:
1009
- main_input_name = "inputs_embeds"
1010
- artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
1011
- self.embed_tokens = self._create_embedding_layer()
1012
- self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
1013
- else:
1014
- self.embed_tokens = None
1015
-
1016
- # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
1017
- dec_attn_mask = torch.zeros(
1018
- self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
1019
- )
1020
- block_tables = torch.zeros(
1021
- self.rbln_config.batch_size,
1022
- self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
1023
- dtype=torch.int16,
1024
- ).fill_(-1)
1025
- free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
1026
-
1027
- self.prefill_decoder = RBLNRuntimeModel(
1028
- runtime=self.model[0],
1029
- main_input_name=main_input_name,
1030
- embed_tokens=self.embed_tokens,
1031
- phase="prefill",
1032
- batch_size=self.rbln_config.batch_size,
1033
- dec_attn_mask=dec_attn_mask,
1034
- block_tables=block_tables,
1035
- free_block_pool=free_block_pool,
1036
- rbln_config=self.rbln_config,
1037
- vocab_size=self.config.vocab_size,
1038
- )
1039
-
1040
- if self.can_generate():
1041
- self.decoders = {}
1042
- for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
1043
- self.decoders[batch_size] = RBLNRuntimeModel(
1044
- runtime=self.model[i + 1],
1045
- main_input_name=main_input_name,
1046
- embed_tokens=self.embed_tokens,
1047
- phase="decode",
1048
- batch_size=batch_size,
1049
- dec_attn_mask=dec_attn_mask,
1050
- block_tables=block_tables,
1051
- free_block_pool=free_block_pool,
1052
- rbln_config=self.rbln_config,
1053
- )
1054
-
1055
- # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
1056
- self.decoder = self.decoders[self.rbln_config.batch_size]
1057
-
1058
- @classmethod
1059
- def get_quantized_model(
1060
- cls,
1061
- model_id: str,
1062
- config: Optional[PretrainedConfig] = None,
1063
- use_auth_token: Optional[Union[bool, str]] = None,
1064
- revision: Optional[str] = None,
1065
- force_download: bool = False,
1066
- cache_dir: Optional[str] = None,
1067
- subfolder: str = "",
1068
- local_files_only: bool = False,
1069
- trust_remote_code: bool = False,
1070
- **kwargs,
1071
- ):
1072
- kwargs = cls.update_kwargs(kwargs)
1073
-
1074
- if config is None:
1075
- config = AutoConfig.from_pretrained(
1076
- model_id,
1077
- use_auth_token=use_auth_token,
1078
- revision=revision,
1079
- force_download=force_download,
1080
- cache_dir=cache_dir,
1081
- trust_remote_code=trust_remote_code,
1082
- **kwargs,
1083
- )
1084
-
1085
- with no_init_weights():
1086
- model = AutoModelForCausalLM.from_config(config)
1087
-
1088
- model = prepare_model_for_quantization(
1089
- model,
1090
- model_id,
1091
- kwargs.get("num_hidden_layers"),
1092
- use_auth_token=use_auth_token,
1093
- revision=revision,
1094
- cache_dir=cache_dir,
1095
- force_download=force_download,
1096
- local_files_only=local_files_only,
1097
- )
1098
- return model
1099
-
1100
- def __getattr__(self, __name: str) -> Any:
1101
- # Special method to delegate attribute access to the original Huggingface LM class.
1102
- # This method is called when an attribute is not found in the current instance's dictionary.
1103
- # It enables transparent access to the original model's attributes and methods while maintaining
1104
- # proper method binding.
1105
-
1106
- # The method implements a delegation pattern that:
1107
-
1108
- # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
1109
- # 2. For other attributes: Returns them directly from the original class
1110
-
1111
- def redirect(func):
1112
- return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
1113
-
1114
- val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
1115
- if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
1116
- return redirect(val)
1117
- return val
1118
-
1119
-    @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
-        wrapper_cfg = {
-            "max_seq_len": rbln_config.max_seq_len,
-            "attn_impl": rbln_config.attn_impl,
-            "kvcache_partition_len": rbln_config.kvcache_partition_len,
-            "kvcache_block_size": rbln_config.kvcache_block_size,
-            "use_rotary_emb": cls._use_rotary_emb,
-            "use_attention_mask": rbln_config.use_attention_mask,
-            "use_position_ids": rbln_config.use_position_ids,
-            "use_inputs_embeds": rbln_config.use_inputs_embeds,
-            "cache_impl": rbln_config.cache_impl,
-            "sliding_window": rbln_config.sliding_window,
-            "sliding_window_layers": rbln_config.sliding_window_layers,
-        }
-        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-        prefill_compile_config = rbln_config.compile_cfgs[0]
-
-        # Here we use meta tensor, for the memory efficiency.
-        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-        context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
-
-        compiled_models = {}
-        compiled_models["prefill"] = cls._compile_model(
-            wrapped_model,
-            prefill_compile_config,
-            prefill_example_inputs,
-            context,
-            rbln_config,
-            rbln_config.quantization,
-            phase="prefill",
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.vocab_size,
         )
 
-        if rbln_config.can_generate:
-            wrapped_model.phase = "decode"
-            for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
-                dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-                compiled_decoder = cls._compile_model(
-                    wrapped_model,
-                    dec_compile_config,
-                    dec_example_inputs,
-                    context,
-                    rbln_config,
-                    rbln_config.quantization,
-                    phase="decode",
-                )
-                compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
-        # check if the memory is enough to have additional blocks
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        if rbln_config.kvcache_num_blocks < required_num_blocks:
-            cls.maybe_suggest_kvcache_num_blocks(
-                compiled_models=compiled_models,
-                model_config=model.config,
-                rbln_config=rbln_config,
-            )
-
-        return compiled_models
-
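The `prefill_output_size` property added in this hunk fixes the shape of the prefill logits buffer. A standalone sketch with hypothetical chunk and vocabulary sizes (values are assumptions, not from the diff):

```python
def prefill_output_size(prefill_chunk_size: int, logits_to_keep: int, vocab_size: int):
    # Mirrors the property added above.
    return (1, prefill_chunk_size if logits_to_keep == 0 else 1, vocab_size)


print(prefill_output_size(128, 0, 32000))  # (1, 128, 32000): logits for every prefill position
print(prefill_output_size(128, 1, 32000))  # (1, 1, 32000): only the last position's logits
```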
     @classmethod
     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
         return is_prefill
 
-    @classmethod
-    def _update_attention_config(
-        cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-    ):
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        max_num_blocks = required_num_blocks
-
-        if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.get_maximum_num_blocks(
-                config=model_config,
-                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
-            )
-
-            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
-
-            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
-                max_num_blocks = flash_min_blocks
-
-        if max_num_blocks < rbln_config.batch_size:
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-        elif rbln_config.kvcache_num_blocks > max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-        return rbln_config
-
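The block arithmetic removed here (and mirrored in the `get_compiled_model` removal above) is simple; a worked example with hypothetical sizes:

```python
# Hypothetical values, for illustration only.
max_seq_len = 8192
kvcache_block_size = 1024
batch_size = 2

# Blocks needed so every sequence in the batch can reach max_seq_len.
required_num_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 8 * 2 = 16

# Lower bound applied for flash attention when batch_size > 1.
flash_min_blocks = max_seq_len // kvcache_block_size + 1  # 9
```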
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional[PreTrainedModel] = None,
-        model_config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
-    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
-        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
-        if rbln_config.can_generate:
-            compile_configs = rbln_config.compile_cfgs
-            for batch_size in rbln_config.decoder_batch_sizes:
-                dec_input_info = cls.get_input_info(
-                    batch_size=batch_size,
-                    query_length=1,
-                    rbln_config=rbln_config,
-                    model_config=model_config,
-                )
-                compile_configs.append(
-                    RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-                )
-            rbln_config.set_compile_cfgs(compile_configs)
-
-        return rbln_config
-
-    @classmethod
-    def _create_runtimes(
-        cls,
-        compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ) -> List[rebel.Runtime]:
-        expected_model_names = ["prefill"]
-        if rbln_config.can_generate:
-            expected_model_names.extend(
-                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
-            )
-        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
-            cls._raise_missing_compiled_file_error(expected_model_names)
-
-        ret_val = [
-            rebel.Runtime(
-                compiled_models[0],
-                tensor_type="pt",
-                device=rbln_config.device_map["prefill"],
-                activate_profiler=rbln_config.activate_profiler,
-                timeout=rbln_config.timeout,
-            )
-        ]
-        if rbln_config.can_generate:
-            ret_val.extend(
-                [
-                    rebel.Runtime(
-                        compiled_models[i + 1],
-                        tensor_type="pt",
-                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                        activate_profiler=rbln_config.activate_profiler,
-                        timeout=rbln_config.timeout,
-                    )
-                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-                ]
-            )
-        return ret_val
-
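The removed `_create_runtimes` requires one entry in `device_map` per expected runtime name; a standalone sketch of the naming scheme with hypothetical batch sizes and device ids:

```python
# Hypothetical setup: one prefill runtime plus decoders compiled for batch sizes 1 and 4.
decoder_batch_sizes = [1, 4]
expected_model_names = ["prefill"] + [f"decoder_batch_{b}" for b in decoder_batch_sizes]

device_map = {"prefill": 0, "decoder_batch_1": 0, "decoder_batch_4": 0}  # hypothetical devices
assert all(name in device_map for name in expected_model_names)
```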
-    def get_decoder(self):
-        if not self.can_generate():
-            raise ValueError("Decode stage is not supported in this model.")
-        return self.decoder
-
-    def can_generate(self):
-        return self.rbln_config.can_generate
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        raise NotImplementedError
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        padded_cache_lengths: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        model_inputs = {}
-        is_prefill_phase = generate_idx is None
-
-        if is_prefill_phase:
-            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
-            padded_cache_lengths = torch.zeros_like(generate_idx)
-            cache_position = None
-            position_ids = None
-        else:
-            if inputs_embeds is not None:
-                # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
-                inputs_embeds = None
-
-            input_ids = input_ids[:, -1:]
-            position_ids = generate_idx
-            cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
-            generate_idx = generate_idx + 1
-            model_inputs.update({"input_ids": input_ids})
-
-        if inputs_embeds is not None:
-            if self.rbln_config.use_inputs_embeds:
-                model_inputs.update({"inputs_embeds": inputs_embeds})
-            else:
-                raise ValueError(
-                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
-                )
-        else:
-            model_inputs.update({"input_ids": input_ids})
-
-        model_inputs.update(
-            {
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-                "generate_idx": generate_idx,
-                "position_ids": position_ids,
-                "padded_cache_lengths": padded_cache_lengths,
-            }
-        )
-
-        return model_inputs
-
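In the removed `prepare_inputs_for_generation`, `generate_idx` starts as each sequence's prompt length and then advances by one per decode step. A small worked example with a hypothetical left-padded batch:

```python
import torch

# Hypothetical left-padded attention mask for two prompts of lengths 3 and 5.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()  # tensor([[3], [5]])
padded_cache_lengths = torch.zeros_like(generate_idx)

# After one decode step:
cache_position = generate_idx + padded_cache_lengths  # tensor([[3], [5]])
generate_idx = generate_idx + 1                       # tensor([[4], [6]])
```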
-    def _update_model_kwargs_for_generation(
-        self,
-        outputs: RBLNDecoderOnlyForCausalLMOutput,
-        model_kwargs: Dict[str, Any],
-        **kwargs,
-    ) -> Dict[str, Any]:
-        # update generate_idx
-        model_kwargs["generate_idx"] = outputs.generate_idx
-        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
-
-        return model_kwargs
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1441,6 +751,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
         if not return_dict:
             return logits, generate_idx, padded_cache_lengths
         else:
-            return RBLNDecoderOnlyForCausalLMOutput(
+            return RBLNDecoderOnlyOutput(
                 logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
             )