optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py (new file)
@@ -0,0 +1,504 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import deque
+ from typing import Any, Optional
+
+ import rebel
+ import torch
+ import torch.nn.functional as F
+
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNPageTableManager:
+     EMPTY_BLOCK = -1
+     NO_BLOCKS_ERROR = (
+         "No memory blocks are available for allocation. "
+         "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
+         "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
+         "Using vllm-rbln should fix this issue and enhance inference performance."
+     )
+
+     def __init__(self, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         self.rbln_config = rbln_config
+         self.block_tables = torch.zeros(
+             self.rbln_config.batch_size,
+             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+             dtype=torch.int16,
+         ).fill_(self.EMPTY_BLOCK)
+         self.free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
+     def update_block(self, batch_idx: int, block_idx: int):
+         """
+         If the block at (batch_idx, block_idx) is still EMPTY_BLOCK, allocates a block from the free_block_pool.
+         """
+         if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
+             if self.free_block_pool:
+                 block = self.free_block_pool.popleft()
+                 self.block_tables[batch_idx][block_idx] = block
+             else:
+                 raise RuntimeError(self.NO_BLOCKS_ERROR)
+
+     def replace_empty_block(self, block_tables: torch.Tensor):
+         """
+         Replaces all occurrences of `self.EMPTY_BLOCK` in `block_tables` with a dummy block from `self.free_block_pool`.
+         """
+         if not torch.any(block_tables == self.EMPTY_BLOCK):
+             return block_tables.clone()
+         elif self.free_block_pool:
+             _free_block = self.free_block_pool[0]
+             return torch.where(block_tables == self.EMPTY_BLOCK, _free_block, block_tables)
+         else:
+             raise RuntimeError(self.NO_BLOCKS_ERROR)
+
+     def get_block_tables(
+         self, cache_position: torch.Tensor, batch_idx: int = None, batch_size: int = None, phase: str = "prefill"
+     ) -> torch.Tensor:
+         """
+         Manages and returns the KV cache block tables.
+         Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+         Args:
+             cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+             batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+         Returns:
+             A tuple of (global block tables, local block tables); each element is None if the corresponding attention type is unused.
+         """
+
+         def get_global_block_tables():
+             if not self.rbln_config.use_global_attention:
+                 return None
+
+             if phase == "prefill":
+                 # Return previously used blocks to the free_block_pool and
+                 # reset the current batch's block table to empty blocks
+                 prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.EMPTY_BLOCK].tolist()
+                 self.free_block_pool.extend(prev_blocks)
+                 self.block_tables[batch_idx].fill_(self.EMPTY_BLOCK)
+
+                 # Get the start (s) and end (e) positions from cache_position and
+                 # iterate over the cache positions to allocate necessary blocks
+                 s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                 for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                     block_idx = position // self.rbln_config.kvcache_block_size
+                     if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                         raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                     self.update_block(batch_idx, block_idx)
+
+                 return self.replace_empty_block(self.block_tables[batch_idx])
+             # Case for the 'decode' phase: iterate over the cache positions to allocate necessary blocks
+             else:
+                 for b_idx in range(batch_size):
+                     position = cache_position[b_idx][0].item()
+                     block_idx = position // self.rbln_config.kvcache_block_size
+                     self.update_block(b_idx, block_idx)
+
+                 return self.replace_empty_block(self.block_tables)
+
+         def get_local_block_tables():
+             if not self.rbln_config.use_local_attention:
+                 return None
+             else:
+                 return (
+                     torch.tensor([batch_idx], dtype=torch.int16)
+                     if phase == "prefill"
+                     else torch.arange(batch_size, dtype=torch.int16).view(batch_size, -1)
+                 )
+
+         return get_global_block_tables(), get_local_block_tables()
+
+     # Whether block_tables and local_block_tables are provided by the user
+     def is_external_block_tables(
+         self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+     ):
+         if self.rbln_config.cache_impl == "static" and block_tables is None:
+             return False
+         elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+             return False
+         elif self.rbln_config.cache_impl == "hybrid":
+             if (block_tables is not None) != (local_block_tables is not None):
+                 raise ValueError(
+                     "Either both block_tables and local_block_tables must be provided, or neither of them."
+                 )
+             elif block_tables is None and local_block_tables is None:
+                 return False
+
+         return True
+
+     def get_block_tables_if_needed(
+         self,
+         batch_size,
+         cache_position: torch.Tensor,
+         batch_idx: int = None,
+         phase: str = "prefill",
+         block_tables: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+     ):
+         is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+         if not is_external_block_tables:
+             block_tables, local_block_tables = self.get_block_tables(
+                 cache_position, batch_idx=batch_idx, batch_size=batch_size, phase=phase
+             )
+
+         return block_tables, local_block_tables, is_external_block_tables
+
+
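For orientation, the allocation scheme of `RBLNPageTableManager` above reduces to a shared free pool plus a per-sequence table of `max_seq_len // kvcache_block_size` logical slots, each filled on first touch. The sketch below is illustrative only; it is not part of the package, and all sizes are made up for the example.

```python
# Illustrative sketch (not part of the diff): the bookkeeping RBLNPageTableManager
# performs, reduced to plain PyTorch. The config values are hypothetical.
from collections import deque

import torch

EMPTY_BLOCK = -1
batch_size, max_seq_len, block_size, num_blocks = 2, 1024, 256, 8  # hypothetical values

block_tables = torch.full((batch_size, max_seq_len // block_size), EMPTY_BLOCK, dtype=torch.int16)
free_block_pool = deque(range(num_blocks))

def allocate(batch_idx: int, position: int) -> None:
    """Allocate a physical block for the logical block covering `position`, if not already mapped."""
    block_idx = position // block_size
    if block_tables[batch_idx, block_idx] == EMPTY_BLOCK:
        if not free_block_pool:
            raise RuntimeError("No memory blocks are available for allocation.")
        block_tables[batch_idx, block_idx] = free_block_pool.popleft()

# Prefilling a 600-token prompt for sequence 0 touches logical blocks 0..2.
for pos in range(0, 600, block_size):
    allocate(0, pos)
print(block_tables[0])  # tensor([ 0,  1,  2, -1], dtype=torch.int16)
```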
+ class RBLNRuntimeModel(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name", "embed_tokens"]
+
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         phase: str,
+         batch_size: int,
+         dec_attn_mask: torch.Tensor,
+         page_table_manager: RBLNPageTableManager,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+         out_buffers: Optional[torch.Tensor] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.phase = phase
+         self.batch_size = batch_size
+         self.rbln_config = rbln_config
+
+         # Shared resources between the prefill and decode phases
+         self.dec_attn_mask = dec_attn_mask
+         self.page_table_manager = page_table_manager
+
+         if self.phase == "prefill":
+             self.out_buffers = out_buffers
+             self.causal_mask = 1 - torch.triu(
+                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
+             )
+
+         self.lora_int_ids = None
+
+     def inputs_embeddings_if_needed(
+         self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
+     ):
+         if input_ids is None and inputs_embeds is None:
+             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
+
+         if self.rbln_config.use_inputs_embeds:
+             return self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
+         else:
+             return input_ids
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         cache_position: torch.Tensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         batch_idx: Optional[int] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
+     ):
+         inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
+         block_tables, local_block_tables, is_external_block_tables = (
+             self.page_table_manager.get_block_tables_if_needed(
+                 self.batch_size,
+                 cache_position,
+                 batch_idx=batch_idx,
+                 phase=self.phase,
+                 block_tables=block_tables,
+                 local_block_tables=local_block_tables,
+             )
+         )
+
+         if self.phase == "decode":
+             return self.decode_forward(
+                 inputs,
+                 cache_position,
+                 block_tables,
+                 is_external_block_tables,
+                 attention_mask=attention_mask,
+                 position_embed=position_embed,
+                 position_ids=position_ids,
+                 local_block_tables=local_block_tables,
+                 lora_int_ids=lora_int_ids,
+             )
+         else:
+             return self.prefill_forward(
+                 inputs,
+                 cache_position,
+                 attention_mask,
+                 batch_idx,
+                 block_tables,
+                 is_external_block_tables=is_external_block_tables,
+                 position_embed=position_embed,
+                 token_type_ids=token_type_ids,
+                 local_block_tables=local_block_tables,
+                 lora_int_ids=lora_int_ids,
+             )
+
+     def decode_forward(
+         self,
+         inputs: torch.Tensor,
+         cache_position: torch.Tensor = None,
+         block_tables: torch.Tensor = None,
+         is_external_block_tables: bool = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         if self.rbln_config.use_lora and lora_int_ids is None:
+             if self.lora_int_ids is None:
+                 raise ValueError(
+                     "lora_int_id is required when using LoRA. "
+                     "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                 )
+
+             lora_int_ids = self.lora_int_ids
+
+         if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+             raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
+         if self.batch_size != cache_position.shape[0]:
+             raise RuntimeError(
+                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
+             )
+
+         if self.rbln_config.use_attention_mask and attention_mask is None:
+             for b_idx in range(self.batch_size):
+                 decoding_step = cache_position[b_idx].item()
+                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                     raise ValueError(
+                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                     )
+
+                 if is_external_block_tables:
+                     self.dec_attn_mask[b_idx].fill_(0)
+                     self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                 else:
+                     self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+             attention_mask = self.dec_attn_mask
+
+         logits = super().forward(
+             inputs,
+             cache_position,
+             block_tables,
+             local_block_tables,
+             position_embed,
+             attention_mask if self.rbln_config.use_attention_mask else None,
+             position_ids if self.rbln_config.use_position_ids else None,
+             lora_int_ids if self.rbln_config.use_lora else None,
+         )
+
+         return RBLNDecoderOnlyOutput(logits=logits)
+
+     def _prepare_prefill_inputs(
+         self,
+         inputs: torch.Tensor,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+     ):
+         """
+         Prepare inputs for the prefill phase.
+         """
+         # Handle continuous batching in a compiled graph by extracting valid inputs
+         # If an attention mask is provided, select only the valid (non-masked) inputs
+         if attention_mask is not None:
+             inputs = inputs[:, attention_mask.bool()]
+             position_embed = None if position_embed is None else position_embed[:, :, :, attention_mask.bool(), :]
+             token_type_ids = None if token_type_ids is None else token_type_ids[:, attention_mask.bool()]
+
+         query_length = inputs.shape[1]
+         if query_length > self.rbln_config.max_seq_len:
+             raise ValueError(
+                 f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
+             )
+
+         # Initialize attention mask for chunked processing
+         chunked_attention_mask = (
+             torch.zeros(
+                 1,
+                 1,
+                 self.rbln_config.prefill_chunk_size,
+                 self.rbln_config.max_seq_len,
+                 dtype=self.rbln_config.torch_dtype,
+             )
+             if self.rbln_config.use_attention_mask
+             else None
+         )
+
+         cache_position = (
+             torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
+         )
+         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+         padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+         if padding_size > 0:
+             inputs = (
+                 F.pad(inputs, (0, 0, 0, padding_size))
+                 if self.rbln_config.use_inputs_embeds
+                 else F.pad(inputs, (0, padding_size))
+             )
+             position_embed = F.pad(position_embed, (0, 0, 0, padding_size)) if position_embed is not None else None
+             token_type_ids = F.pad(token_type_ids, (0, padding_size), value=-1) if token_type_ids is not None else None
+             cache_position = F.pad(cache_position, (0, padding_size))
+
+         # Initialize position_ids and padded_cache_lengths
+         position_ids = cache_position.clone() if self.rbln_config.use_position_ids else None
+         padded_cache_lengths = 0
+
+         return (
+             inputs,
+             cache_position,
+             chunked_attention_mask,
+             position_ids,
+             position_embed,
+             padded_cache_lengths,
+             query_length,
+             token_type_ids,
+         )
+
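As a side note, the padding rule in `_prepare_prefill_inputs` above rounds every prompt up to the next multiple of `prefill_chunk_size`. A quick illustrative check (the chunk size below is a made-up value, not one shipped by the package):

```python
# Illustrative only: the padding rule used by _prepare_prefill_inputs above.
prefill_chunk_size = 128  # hypothetical compile-time chunk size
for query_length in (1, 100, 128, 300):
    padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size
    padded = query_length + padding_size
    assert padded % prefill_chunk_size == 0
    print(query_length, "->", padded)  # 1 -> 128, 100 -> 128, 128 -> 128, 300 -> 384
```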
+     def prefill_forward(
+         self,
+         inputs: torch.Tensor,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         batch_idx: Optional[int] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         is_external_block_tables: Optional[bool] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         """
+         Performs chunked prefill for efficient KV-cache updates and memory optimization.
+         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+         """
+         if self.rbln_config.use_lora and lora_int_ids is None:
+             if self.lora_int_ids is None:
+                 raise ValueError(
+                     "lora_int_id is required when using LoRA. "
+                     "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                 )
+
+             if batch_idx is not None:
+                 lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+             else:
+                 lora_int_ids = self.lora_int_ids.clone()
+
+         (
+             inputs,
+             cache_position,
+             chunked_attention_mask,
+             position_ids,
+             position_embed,
+             padded_cache_lengths,
+             query_length,
+             token_type_ids,
+         ) = self._prepare_prefill_inputs(
+             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+         )
+
+         # Assume prefix caching was performed externally if cache_position does not start from 0.
+         prefix_cached_len = cache_position[0][0].item()
+         if prefix_cached_len > 0:
+             if prefix_cached_len % self.rbln_config.prefill_chunk_size != 0:
+                 raise NotImplementedError(
+                     "Prefix caching is not yet supported when the cached length is not a multiple of prefill_chunk_size."
+                 )
+             if self.rbln_config.use_attention_mask:
+                 chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
+
+         # Process input in chunks of size `prefill_chunk_size`
+         output_logits = []
+         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+             s, e = step, step + self.rbln_config.prefill_chunk_size
+             # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
+             input_chunk = inputs[:, s:e]
+             cache_pos_chunk = cache_position[:, s:e]
+             position_ids_chunk = position_ids[:, s:e] if self.rbln_config.use_position_ids else None
+             position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None
+
+             # Update attention mask to ensure proper causal behavior
+             if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
+                 if step > 0:  # update previous chunk
+                     chunked_attention_mask[
+                         :,
+                         :,
+                         :,
+                         s - self.rbln_config.prefill_chunk_size + prefix_cached_len : e
+                         - self.rbln_config.prefill_chunk_size
+                         + prefix_cached_len,
+                     ] = 1
+                 chunked_attention_mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = self.causal_mask
+
+             # Calculate query position if needed
+             if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
+                 query_position = (
+                     torch.tensor((query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16)
+                     if e >= query_length
+                     else torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
+                 )
+             else:
+                 query_position = None
+
+             # Forward pass for the current chunk
+             output_logit = super().forward(
+                 input_chunk,
+                 cache_pos_chunk,
+                 block_tables,
+                 local_block_tables,
+                 position_embed_chunk,
+                 query_position,
+                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                 position_ids_chunk,
+                 lora_int_ids if self.rbln_config.use_lora else None,
+                 out=self.out_buffers,
+             )
+             output_logits.append(output_logit)
+
+         # Aggregate output_logits
+         output_logits = torch.concat(output_logits, dim=-2)
+         if self.rbln_config.logits_to_keep > 0:
+             output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
+         else:
+             output_logits = output_logits[:, :query_length, :]
+         # Index-copy logits of valid (non-masked) positions back into the padded layout
+         if attention_mask is not None:
+             new_output_logits = torch.full(
+                 (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                 fill_value=1e-10,
+                 dtype=output_logits.dtype,
+             )
+             mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+             new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+
+             output_logits = new_output_logits
+
+         # Update decoder attention mask with processed KV-cache length from prefill phase
+         if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
+             self.dec_attn_mask[batch_idx].fill_(0)
+             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+
+         return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
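The index-copy step at the end of `prefill_forward` restores the padded token layout expected by the caller after the padded positions were dropped for the compiled graph. The toy example below (illustrative only, with made-up shapes) shows that scatter in isolation:

```python
# Illustrative only: how prefill_forward scatters logits computed on the unpadded
# tokens back into the original (padded) layout using index_copy_, as above.
import torch

attention_mask = torch.tensor([0, 0, 1, 1, 1])  # two padding tokens, three real tokens
vocab = 4
valid_logits = torch.randn(1, int(attention_mask.sum()), vocab)  # logits for the 3 real tokens

restored = torch.full((1, attention_mask.shape[-1], vocab), fill_value=1e-10)
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]  # tensor([2, 3, 4])
restored.index_copy_(dim=-2, index=mask_indices, source=valid_logits)
# restored[:, 2:] now holds the real logits; the padded rows keep the 1e-10 filler.
```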
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py (new file)
@@ -0,0 +1,111 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Any, Dict, Optional
+
+ import torch
+ from transformers.generation.utils import GenerationMixin
+
+
+ if TYPE_CHECKING:
+     from ...modeling_outputs import RBLNDecoderOnlyOutput
+
+
+ class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
+     _supports_cache_class = False  # Needed for GenerationMixin
+     _is_stateful = False  # Needed for GenerationMixin
+
+     def _reorder_cache(self, past_key_values, beam_idx):
+         raise NotImplementedError
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         generate_idx: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         padded_cache_lengths: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         model_inputs = {}
+         is_prefill_phase = generate_idx is None
+
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             padded_cache_lengths = torch.zeros_like(generate_idx)
+             cache_position = None
+             position_ids = None
+         else:
+             if inputs_embeds is not None:
+                 # If `inputs_embeds` are passed, only use them in the first generation step for each prompt.
+                 inputs_embeds = None
+
+             input_ids = input_ids[:, -1:]
+             position_ids = generate_idx
+             cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         if inputs_embeds is not None:
+             if self.rbln_config.use_inputs_embeds:
+                 model_inputs.update({"inputs_embeds": inputs_embeds})
+             else:
+                 raise ValueError(
+                     "Passing inputs_embeds is only supported when the RBLN model was compiled with 'rbln_use_inputs_embeds' set to True."
+                 )
+         else:
+             model_inputs.update({"input_ids": input_ids})
+
+         model_inputs.update(
+             {
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
+                 "position_ids": position_ids,
+                 "padded_cache_lengths": padded_cache_lengths,
+             }
+         )
+
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(
+         self, outputs: "RBLNDecoderOnlyOutput", model_kwargs: Dict[str, Any], **kwargs
+     ) -> Dict[str, Any]:
+         # Update generate_idx and padded_cache_lengths carried between generation steps
+         model_kwargs["generate_idx"] = outputs.generate_idx
+         model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
+         return model_kwargs
+
+     def generate(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.LongTensor] = None,
+         max_length: Optional[int] = None,
+         **kwargs,
+     ):
+         """
+         The generate function behaves as in the standard HuggingFace Transformers library and can be used to generate text from the model.
+
+         Args:
+             input_ids: The input ids to the model.
+             attention_mask: The attention mask to the model.
+             max_length: The maximum length of the generated text.
+             kwargs: Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+         """
+         if max_length is not None:
+             kwargs["max_length"] = max_length
+         if attention_mask is not None:
+             kwargs["attention_mask"] = attention_mask
+
+         return super().generate(input_ids, **kwargs)
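To see what this mixin's bookkeeping amounts to, the trace below replays the prefill and decode branches of `prepare_inputs_for_generation` with toy tensors. It is illustrative only and is not how the class is normally driven; in practice `generate()` runs the standard HuggingFace generation loop, which calls these hooks internally.

```python
# Illustrative only: the per-step bookkeeping performed by prepare_inputs_for_generation,
# traced with toy tensors (no RBLN model involved).
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # two prompts, lengths 3 and 4

# Prefill step: generate_idx is None, so it is derived from the attention mask.
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()  # tensor([[3], [4]])
padded_cache_lengths = torch.zeros_like(generate_idx)

# Subsequent decode steps: only the last token is fed, and the cache position is the
# running generate_idx plus any padding the prefill phase introduced.
for _ in range(2):
    cache_position = generate_idx + padded_cache_lengths
    position_ids = generate_idx
    generate_idx = generate_idx + 1
print(generate_idx.squeeze(-1).tolist())  # [5, 6]
```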