optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3a0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Files changed (90)
  1. optimum/rbln/__init__.py +4 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +2 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +6 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +19 -250
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  81. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  82. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  83. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  84. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  85. optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
  86. optimum/rbln/utils/runtime_utils.py +3 -3
  87. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
  88. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
  89. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
  90. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -13,10 +13,7 @@
 # limitations under the License.


-from ...models.decoderonly.decoderonly_architecture import (
-    DecoderOnlyModel,
-    DecoderOnlyWrapper,
-)
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyModel, DecoderOnlyWrapper


 class GemmaWrapper(DecoderOnlyWrapper):
optimum/rbln/transformers/models/gemma/modeling_gemma.py
@@ -90,6 +90,15 @@ class RBLNGemmaModel(RBLNDecoderOnlyModel):

     A class to convert and run pre-trained transformers based GemmaModel model on RBLN devices.
     It implements the methods to convert a pre-trained transformers GemmaModel model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemmaModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemmaModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemmaModelConfig`] class for all available configuration options.
     """

     _decoder_wrapper_cls = GemmaWrapper
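
The configuration contract described in the new docstring reduces to a single loading call. A minimal usage sketch: the checkpoint ID and the `rbln_config` keys below are illustrative assumptions, and `export=True` follows the usual Optimum loading convention rather than anything stated in this diff; verify the exact keys against `RBLNGemmaModelConfig`.

# Minimal usage sketch for the docstring above. The checkpoint ID and the
# rbln_config keys are illustrative assumptions, not taken from this diff.
from optimum.rbln import RBLNGemmaModel

model = RBLNGemmaModel.from_pretrained(
    "google/gemma-2b",              # hypothetical checkpoint
    export=True,                    # compile the checkpoint into an RBLN graph
    rbln_config={"batch_size": 1},  # dict conforming to RBLNGemmaModelConfig
)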
optimum/rbln/transformers/models/gemma3/configuration_gemma3.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig
 from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
@@ -25,7 +25,7 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         use_attention_mask: Optional[bool] = None,
         prefill_chunk_size: Optional[int] = None,
         image_prefill_chunk_size: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         # use_attention_mask and use_position_ids are always True for Gemma3
         use_attention_mask = use_attention_mask or True
@@ -57,7 +57,7 @@ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_tower: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py (new file)
@@ -0,0 +1,217 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import rebel
+import torch
+
+from ...modeling_outputs import RBLNDecoderOnlyOutput, RBLNGemma3ForCausalLMOutput
+from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
+
+
+class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
+    def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.image_prefill = image_prefill  # FIXME(taehoon)
+        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
+        self.decode = self.runtime if self.phase == "decode" else None
+
+    def _prepare_prefill_inputs(self, *args, **kwargs):
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        ) = super()._prepare_prefill_inputs(*args, **kwargs)
+
+        # chunked_attention_mask shape
+        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
+
+        # In case of Gemma3ForConditionalGeneration, the loop counter may not be a prefill_chunk_size,
+        # so we cannot guarantee that the last chunk starts at a position that is a multiple of prefill_chunk_size.
+        if self.rbln_config.use_image_prefill:
+            padding_size = self.rbln_config.image_prefill_chunk_size
+            inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+            cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
+            position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
+            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
+
+        return (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        )
+
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        ) = self._prepare_prefill_inputs(
+            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+        )
+
+        step = 0
+        while step < query_length:
+            if self.rbln_config.use_image_prefill:
+                # Check if the prefill chunk is an image prefill
+                is_image_prefill = torch.all(
+                    token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
+                )
+                # Check if the prefill chunk is a text prefill which has image_tokens in it.
+                is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
+                    token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
+                )
+            else:
+                is_image_prefill, is_text_prefill_with_image_tokens = False, False
+
+            # Check if the prefill chunk is the last chunk
+            is_last_chunk = step + self.rbln_config.prefill_chunk_size >= query_length
+
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = (
+                cache_position[:, step : step + self.rbln_config.prefill_chunk_size] + padded_cache_lengths
+            )
+            position_ids_chunk = position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+
+            # If text_prefill ends with image_tokens, we only treat the text part.
+            num_processed_tokens = self.rbln_config.prefill_chunk_size
+            current_padded_cache_lengths = 0
+            if is_text_prefill_with_image_tokens:
+                first_image_token_idx = torch.where(
+                    token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
+                )[1][0]
+                num_processed_tokens = first_image_token_idx.item()
+                current_padded_cache_lengths = self.rbln_config.prefill_chunk_size - num_processed_tokens
+            if is_last_chunk:
+                num_processed_tokens = query_length - step
+
+            chunked_attention_mask[
+                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
+            ] = 1
+            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
+
+            if is_image_prefill:
+                logits = self.image_prefill(
+                    input_chunk,
+                    cache_pos_chunk,
+                    block_tables,
+                    local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
+                )
+            else:
+                logits = self.prefill(
+                    input_chunk,
+                    cache_pos_chunk,
+                    block_tables,
+                    local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
+                )
+
+            padded_cache_lengths += current_padded_cache_lengths
+            step += num_processed_tokens
+
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+
+        return RBLNGemma3ForCausalLMOutput(
+            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
+        )
+
+    def decode_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
+        if is_external_block_tables:
+            if attention_mask is None:
+                raise ValueError("attention_mask should be provided with external block tables.")
+            if local_block_tables is None:
+                raise ValueError("local_block_tables should be provided with external block tables.")
+        else:
+            local_block_tables = (
+                local_block_tables
+                if local_block_tables is not None
+                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
+            )
+            if self.rbln_config.use_attention_mask and attention_mask is None:
+                for b_idx in range(batch_size):
+                    decoding_step = cache_position[b_idx].item()
+                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                        raise ValueError(
+                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                        )
+                    self.dec_attn_mask[b_idx, decoding_step] = 1
+
+                attention_mask = self.dec_attn_mask
+
+        if self.batch_size < block_tables.shape[0]:
+            block_tables = block_tables[: self.batch_size]
+
+        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
+            attention_mask = attention_mask[: self.batch_size]
+
+        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
+
+        return RBLNDecoderOnlyOutput(logits=logits)
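
The `prefill_forward` docstring above describes the loop shape: walk the (padded) sequence in `prefill_chunk_size` steps, shortening the final step to the tokens that remain. A standalone sketch of that stepping, with toy sizes and no RBLN runtime or image handling involved:

# Toy illustration of the chunk stepping used by prefill_forward above;
# the sizes are arbitrary and no compiled runtime is invoked.
query_length = 300
prefill_chunk_size = 128

step = 0
while step < query_length:
    is_last_chunk = step + prefill_chunk_size >= query_length
    # The last chunk covers only the tokens that remain.
    num_processed_tokens = query_length - step if is_last_chunk else prefill_chunk_size
    print(f"prefill chunk: tokens [{step}, {step + num_processed_tokens})")
    step += num_processed_tokens
# Prints chunks [0, 128), [128, 256), [256, 300).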
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -12,43 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from collections import deque
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
 from ..decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyForCausalLMOutput,
     RBLNDecoderOnlyModelForCausalLM,
-    RBLNRuntimeModel,
 )
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
+from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


-@dataclass
-class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
-    attention_mask: Optional[torch.Tensor] = None
-
-
 class LoopVisionTower:
     def __init__(self, vision_tower: RBLNModel) -> None:
         self.vision_tower = vision_tower
@@ -201,7 +190,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

     def _update_model_kwargs_for_generation(
         self,
-        outputs: RBLNDecoderOnlyForCausalLMOutput,
+        outputs: RBLNDecoderOnlyOutput,
         model_kwargs: Dict[str, Any],
         **kwargs,
     ) -> Dict[str, Any]:
@@ -298,7 +287,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         **lm_kwargs: Dict[str, Any],
-    ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
+    ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
             logits = []
@@ -339,213 +328,11 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
             ).logits

-        return RBLNDecoderOnlyForCausalLMOutput(
+        return RBLNDecoderOnlyOutput(
            logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
        )


-class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-    def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
-
-    def _prepare_prefill_inputs(self, *args, **kwargs):
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-        # chunked_attention_mask shape
-        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-        # In case of Gemma3ForConditionalGeneration, the loop counter may not be a prefill_chunk_size,
-        # so we cannot guarantee that the last chunk starts at a position that is a multiple of prefill_chunk_size.
-        if self.rbln_config.use_image_prefill:
-            padding_size = self.rbln_config.image_prefill_chunk_size
-            inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-            position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        step = 0
-        while step < query_length:
-            if self.rbln_config.use_image_prefill:
-                # Check if the prefill chunk is an image prefill
-                is_image_prefill = torch.all(
-                    token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-                )
-                # Check if the prefill chunk is a text prefill which has image_tokens in it.
-                is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
-                    token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
-                )
-            else:
-                is_image_prefill, is_text_prefill_with_image_tokens = False, False
-
-            # Check if the prefill chunk is the last chunk
-            is_last_chunk = step + self.rbln_config.prefill_chunk_size >= query_length
-
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = (
-                cache_position[:, step : step + self.rbln_config.prefill_chunk_size] + padded_cache_lengths
-            )
-            position_ids_chunk = position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-
-            # If text_prefill ends with image_tokens, we only treat the text part.
-            num_processed_tokens = self.rbln_config.prefill_chunk_size
-            current_padded_cache_lengths = 0
-            if is_text_prefill_with_image_tokens:
-                first_image_token_idx = torch.where(
-                    token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
-                )[1][0]
-                num_processed_tokens = first_image_token_idx.item()
-                current_padded_cache_lengths = self.rbln_config.prefill_chunk_size - num_processed_tokens
-            if is_last_chunk:
-                num_processed_tokens = query_length - step
-
-            chunked_attention_mask[
-                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-            ] = 1
-            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-            if is_image_prefill:
-                logits = self.image_prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-            else:
-                logits = self.prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-
-            padded_cache_lengths += current_padded_cache_lengths
-            step += num_processed_tokens
-
-        if not is_external_block_tables:
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
-        else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-        return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
-
-
 class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma3 Model transformer with a language modeling head (linear layer) on top.
@@ -559,52 +346,34 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):

     _decoder_wrapper_cls = Gemma3ForCausalLMWrapper

-    def __post_init__(self, **kwargs):
-        main_input_name = self.main_input_name
-
-        if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
-            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            self.embed_tokens = self._create_embedding_layer()
-            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        else:
-            self.embed_tokens = None
-
+    def setup_runtime(self):
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-        block_tables = torch.zeros(
-            self.rbln_config.batch_size,
-            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }

         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-            dec_attn_mask=dec_attn_mask,
-            block_tables=block_tables,
-            vocab_size=self.config.vocab_size,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
+            **common_kwargs,
         )

         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
                 runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
                 phase="decode",
                 batch_size=batch_size,
-                dec_attn_mask=dec_attn_mask,
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
+                **common_kwargs,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
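
`setup_runtime` now hands KV-cache block bookkeeping to `RBLNPageTableManager` rather than building the block table and free-block pool inline. A rough sketch of what the removed inline code did (the class and method names below are hypothetical; the real manager lives in `decoderonly_runtime_utils.py` and its API may differ):

# Sketch of paged KV-cache bookkeeping: a per-sequence table of block indices
# plus a pool of free blocks, mirroring the removed block_tables/free_block_pool
# logic above. ToyPageTableManager is a hypothetical name.
from collections import deque

import torch


class ToyPageTableManager:
    def __init__(self, batch_size: int, max_seq_len: int, block_size: int, num_blocks: int):
        # One row per sequence; -1 marks a slot with no cache block assigned yet.
        self.block_tables = torch.full((batch_size, max_seq_len // block_size), -1, dtype=torch.int16)
        self.free_block_pool = deque(range(num_blocks))

    def get_block(self, batch_idx: int, slot: int) -> int:
        # Lazily assign a free block the first time a sequence touches a slot.
        if self.block_tables[batch_idx, slot] == -1:
            self.block_tables[batch_idx, slot] = self.free_block_pool.popleft()
        return int(self.block_tables[batch_idx, slot])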
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -47,6 +47,8 @@ class RBLNGPT2Model(RBLNDecoderOnlyModel):

     A class to convert and run pre-trained transformers based GPT2Model model on RBLN devices.
     It implements the methods to convert a pre-trained transformers GPT2Model model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
     """

     _decoder_wrapper_cls = GPT2Wrapper
optimum/rbln/transformers/models/idefics3/configuration_idefics3.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional
+from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig

@@ -39,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -34,17 +34,11 @@ from transformers.models.idefics3.modeling_idefics3 import Idefics3CausalLMOutputWithPast
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ..decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyForCausalLMOutput,
-)
+from ...modeling_outputs import RBLNDecoderOnlyOutput


 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


 class RBLNRuntimeVisionModel(RBLNPytorchRuntime):
@@ -494,7 +488,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
         if not return_dict:
             return logits, generate_idx
         else:
-            return RBLNDecoderOnlyForCausalLMOutput(
+            return RBLNDecoderOnlyOutput(
                 logits=logits,
                 generate_idx=generate_idx,
             )