optimum-rbln 0.7.5a1__py3-none-any.whl → 0.7.5rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. optimum/rbln/__init__.py +10 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -2
  4. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -2
  5. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -2
  6. optimum/rbln/modeling.py +53 -9
  7. optimum/rbln/modeling_base.py +22 -3
  8. optimum/rbln/transformers/__init__.py +10 -0
  9. optimum/rbln/transformers/modeling_generic.py +0 -19
  10. optimum/rbln/transformers/models/__init__.py +14 -0
  11. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  12. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  13. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +114 -19
  14. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +29 -10
  15. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  16. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  17. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  18. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
  19. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
  20. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
  21. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  22. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  23. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -0
  24. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -1
  25. optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
  26. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
  27. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -3
  28. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/METADATA +1 -1
  29. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/RECORD +31 -27
  30. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/WHEEL +0 -0
  31. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -38,6 +38,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -78,6 +79,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -259,6 +264,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -297,6 +303,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.5a1'
-__version_tuple__ = version_tuple = (0, 7, 5, 'a1')
+__version__ = version = '0.7.5rc1'
+__version_tuple__ = version_tuple = (0, 7, 5, 'rc1')
optimum/rbln/diffusers/models/transformers/prior_transformer.py CHANGED
@@ -58,8 +58,7 @@ class _PriorTransformer(torch.nn.Module):
 class RBLNPriorTransformer(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = PriorTransformer
-    output_class = PriorTransformerOutput
-    output_key = "predicted_image_embedding"
+    _output_class = PriorTransformerOutput
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
optimum/rbln/diffusers/models/transformers/transformer_sd3.py CHANGED
@@ -61,8 +61,7 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
 class RBLNSD3Transformer2DModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = SD3Transformer2DModel
-    output_class = Transformer2DModelOutput
-    output_key = "sample"
+    _output_class = Transformer2DModelOutput
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
optimum/rbln/diffusers/models/unets/unet_2d_condition.py CHANGED
@@ -143,8 +143,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
     hf_library_name = "diffusers"
     auto_model_class = UNet2DConditionModel
     _rbln_config_class = RBLNUNet2DConditionModelConfig
-    output_class = UNet2DConditionOutput
-    output_key = "sample"
+    _output_class = UNet2DConditionOutput
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
optimum/rbln/modeling.py CHANGED
@@ -14,7 +14,7 @@
 
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, get_args, get_origin, get_type_hints
 
 import rebel
 import torch
@@ -49,8 +49,7 @@ class RBLNModel(RBLNBaseModel):
     ```
     """
 
-    output_class = None
-    output_key = "last_hidden_state"
+    _output_class = None
 
     @classmethod
     def update_kwargs(cls, kwargs):
@@ -245,16 +244,61 @@
         # Format output according to task requirements
         return self._prepare_output(output, return_dict)
 
+    @classmethod
+    def get_hf_output_class(cls):
+        """
+        Dynamically gets the output class from the corresponding HuggingFace model class.
+
+        Returns:
+            type: The appropriate output class from transformers or diffusers
+        """
+        if cls._output_class:
+            return cls._output_class
+
+        hf_class = cls.get_hf_class()
+        if hf_class is None:
+            raise ValueError(f"No HuggingFace model class found for {cls.__name__}")
+
+        hints = get_type_hints(hf_class.forward) if hasattr(hf_class, "forward") else {}
+        ret = hints.get("return")
+
+        if ret is not None:
+            candidates = get_args(ret) if get_origin(ret) is Union else (ret,)
+
+            for t in candidates:
+                if t is type(None):  # Skip NoneType in Union
+                    continue
+                mod = getattr(t, "__module__", "")
+                if "transformers" in mod or "diffusers" in mod:
+                    cls._output_class = t
+                    return t
+
+        # Fallback to BaseModelOutput
+        cls._output_class = BaseModelOutput
+        return BaseModelOutput
+
     def _prepare_output(self, output, return_dict):
         """
         Prepare model output based on return_dict flag.
         This method can be overridden by subclasses to provide task-specific output handling.
         """
+        tuple_output = (output,) if not isinstance(output, (tuple, list)) else output
         if not return_dict:
-            return (output,) if not isinstance(output, (tuple, list)) else output
+            return tuple_output
         else:
-            if self.output_class is None:
-                return BaseModelOutput(last_hidden_state=output)
-
-            # Create output with the appropriate class and key
-            return self.output_class(**{self.output_key: output})
+            output_class = self.get_hf_output_class()
+            if hasattr(output_class, "loss"):
+                tuple_output = (None,) + tuple_output
+
+            # Truncate if we have too many outputs, otherwise use as is
+            if hasattr(output_class, "__annotations__"):
+                num_fields = len(output_class.__annotations__)
+                if len(tuple_output) > num_fields:
+                    tuple_output = tuple_output[:num_fields]
+                    logger.warning(
+                        f"Truncating output to {num_fields} fields for {output_class.__name__}. "
+                        f"Expected {num_fields} fields, but got {len(tuple_output)} fields."
+                        "This is unexpected. Please report this issue to the developers."
+                    )
+
+            return output_class(*tuple_output)
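The new `get_hf_output_class` helper replaces the per-class `output_class`/`output_key` attributes: it reads the return annotation of the matching HuggingFace model's `forward`, caches the first `transformers`/`diffusers` output type it finds, and `_prepare_output` then fills that class positionally (inserting a leading `None` when the output type defines a `loss` field). Below is a minimal standalone sketch of the same lookup, using a stock `transformers` class; the helper name is illustrative and not part of this diff.

```python
from typing import Union, get_args, get_origin, get_type_hints

from transformers import BertForSequenceClassification


def resolve_output_class(hf_class):
    # Inspect forward()'s return annotation and pick the first output type
    # defined in transformers or diffusers, mirroring get_hf_output_class above.
    hints = get_type_hints(hf_class.forward) if hasattr(hf_class, "forward") else {}
    ret = hints.get("return")
    candidates = get_args(ret) if get_origin(ret) is Union else (ret,)
    for t in candidates:
        if t is type(None):  # skip NoneType inside Optional[...] annotations
            continue
        module = getattr(t, "__module__", "")
        if "transformers" in module or "diffusers" in module:
            return t
    return None


# Expected to resolve to transformers.modeling_outputs.SequenceClassifierOutput,
# assuming a transformers release that annotates forward()'s return type.
print(resolve_output_class(BertForSequenceClassification))
```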
optimum/rbln/modeling_base.py CHANGED
@@ -178,9 +178,27 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return str(model_path)
 
     @classmethod
-    def _load_compiled_models(cls, model_path: str):
+    def _load_compiled_models(cls, model_path: str, expected_compiled_model_names: List[str]):
         compiled_models = Path(model_path).glob("*.rbln")
-        rbln_compiled_models = {cm.stem: rebel.RBLNCompiledModel(cm) for cm in compiled_models}
+        expected_compiled_models = [
+            Path(model_path) / f"{compiled_model_name}.rbln" for compiled_model_name in expected_compiled_model_names
+        ]
+        unexpected_compiled_models = [cm for cm in compiled_models if cm not in expected_compiled_models]
+        if unexpected_compiled_models:
+            # TODO(jongho): fix after May release. raise error if unexpected compiled models are found
+            logger.warning(
+                f"Unexpected compiled models found: {[cm.name for cm in unexpected_compiled_models]}. "
+                f"Please check the model path: {model_path}"
+            )
+
+        rbln_compiled_models = {}
+        for compiled_model in expected_compiled_models:
+            if not compiled_model.exists():
+                raise FileNotFoundError(
+                    f"Expected RBLN compiled model '{compiled_model.name}' not found at '{model_path}'. "
+                    "Please ensure all models specified in `rbln_config` are present."
+                )
+            rbln_compiled_models[compiled_model.stem] = rebel.RBLNCompiledModel(compiled_model)
         return rbln_compiled_models
 
     @classmethod
@@ -271,7 +289,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         )
         config = PretrainedConfig(**config)
 
-        rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
+        compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
+        rbln_compiled_models = cls._load_compiled_models(model_path_subfolder, compiled_model_names)
 
         if subfolder != "":
             model_save_dir = Path(model_path_subfolder).absolute().parent
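`_load_compiled_models` now takes the list of compiled-model names declared in `rbln_config.compile_cfgs`, warns about stray `*.rbln` files, and raises `FileNotFoundError` when an expected artifact is missing, instead of silently loading whatever the glob returns. A rough, self-contained sketch of that validation; the directory and model names below are made up for illustration.

```python
from pathlib import Path

model_path = Path("exported_model")       # hypothetical export directory
expected_names = ["prefill", "decoder"]   # e.g. cfg.compiled_model_name per compile config

expected = [model_path / f"{name}.rbln" for name in expected_names]
found = set(model_path.glob("*.rbln"))

# Stray artifacts only trigger a warning for now (see the TODO in the diff).
for extra in sorted(found - set(expected)):
    print(f"warning: unexpected compiled model {extra.name}")

# A missing expected artifact is a hard error.
for path in expected:
    if not path.exists():
        raise FileNotFoundError(f"Expected RBLN compiled model '{path.name}' not found in '{model_path}'")
```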
optimum/rbln/transformers/__init__.py CHANGED
@@ -34,6 +34,7 @@ _import_structure = {
         "RBLNAutoModelForCTC",
         "RBLNAutoModelForDepthEstimation",
         "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForMaskedLM",
         "RBLNAutoModelForQuestionAnswering",
         "RBLNAutoModelForSeq2SeqLM",
@@ -72,6 +73,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaForCausalLM",
         "RBLNGemmaForCausalLMConfig",
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
         "RBLNIdefics3VisionTransformer",
@@ -148,6 +153,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -184,6 +190,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
         RBLNGPT2LMHeadModel,
optimum/rbln/transformers/modeling_generic.py CHANGED
@@ -36,11 +36,7 @@ from transformers import (
 )
 from transformers.modeling_outputs import (
     BaseModelOutput,
-    DepthEstimatorOutput,
-    ImageClassifierOutput,
-    MaskedLMOutput,
     QuestionAnsweringModelOutput,
-    SequenceClassifierOutput,
 )
 
 from ..configuration_utils import RBLNCompileConfig
@@ -63,8 +59,6 @@ class _RBLNTransformerEncoder(RBLNModel):
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     rbln_dtype = "int64"
-    output_class = BaseModelOutput
-    output_key = "last_hidden_state"
 
     @classmethod
     def _update_rbln_config(
@@ -149,7 +143,6 @@ class _RBLNImageModel(RBLNModel):
     auto_model_class = AutoModel
     main_input_name = "pixel_values"
     output_class = BaseModelOutput
-    output_key = "last_hidden_state"
 
     @classmethod
     def _update_rbln_config(
@@ -223,15 +216,11 @@ class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
 class RBLNModelForSequenceClassification(_RBLNTransformerEncoder):
     auto_model_class = AutoModelForSequenceClassification
     rbln_model_input_names = ["input_ids", "attention_mask"]
-    output_class = SequenceClassifierOutput
-    output_key = "logits"
 
 
 class RBLNModelForMaskedLM(_RBLNTransformerEncoder):
     auto_model_class = AutoModelForMaskedLM
     rbln_model_input_names = ["input_ids", "attention_mask"]
-    output_class = MaskedLMOutput
-    output_key = "logits"
 
 
 class RBLNModelForTextEncoding(_RBLNTransformerEncoder):
@@ -243,20 +232,14 @@ class RBLNTransformerEncoderForFeatureExtraction(_RBLNTransformerEncoder):
     # TODO: RBLNModel is also for feature extraction.
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask"]
-    output_class = BaseModelOutput
-    output_key = "last_hidden_state"
 
 
 class RBLNModelForImageClassification(_RBLNImageModel):
     auto_model_class = AutoModelForImageClassification
-    output_class = ImageClassifierOutput
-    output_key = "logits"
 
 
 class RBLNModelForDepthEstimation(_RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation
-    output_class = DepthEstimatorOutput
-    output_key = "predicted_depth"
 
 
 class RBLNModelForAudioClassification(RBLNModel):
@@ -273,8 +256,6 @@ class RBLNModelForAudioClassification(RBLNModel):
     """
 
     auto_model_class = AutoModelForAudioClassification
-    output_class = SequenceClassifierOutput
-    output_key = "logits"
 
     @classmethod
     def _update_rbln_config(
optimum/rbln/transformers/models/__init__.py CHANGED
@@ -31,6 +31,7 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNAutoModelForImageTextToText",
     ],
     "bart": [
         "RBLNBartForConditionalGeneration",
@@ -80,6 +81,12 @@ _import_structure = {
     ],
     "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
     "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+    "gemma3": [
+        "RBLNGemma3ForCausalLM",
+        "RBLNGemma3ForCausalLMConfig",
+        "RBLNGemma3ForConditionalGeneration",
+        "RBLNGemma3ForConditionalGenerationConfig",
+    ],
     "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
     "idefics3": [
         "RBLNIdefics3VisionTransformer",
@@ -121,6 +128,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForCTC,
         RBLNAutoModelForDepthEstimation,
         RBLNAutoModelForImageClassification,
+        RBLNAutoModelForImageTextToText,
         RBLNAutoModelForMaskedLM,
         RBLNAutoModelForQuestionAnswering,
         RBLNAutoModelForSeq2SeqLM,
@@ -170,6 +178,12 @@ if TYPE_CHECKING:
     )
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
    from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+    from .gemma3 import (
+        RBLNGemma3ForCausalLM,
+        RBLNGemma3ForCausalLMConfig,
+        RBLNGemma3ForConditionalGeneration,
+        RBLNGemma3ForConditionalGenerationConfig,
+    )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
     from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
optimum/rbln/transformers/models/auto/__init__.py CHANGED
@@ -19,6 +19,7 @@ from .modeling_auto import (
     RBLNAutoModelForCTC,
     RBLNAutoModelForDepthEstimation,
     RBLNAutoModelForImageClassification,
+    RBLNAutoModelForImageTextToText,
     RBLNAutoModelForMaskedLM,
     RBLNAutoModelForQuestionAnswering,
     RBLNAutoModelForSeq2SeqLM,
optimum/rbln/transformers/models/auto/modeling_auto.py CHANGED
@@ -23,6 +23,8 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
     MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -90,6 +92,11 @@ class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
     _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 
 
+class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
+    _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+
+
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
     _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
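`RBLNAutoModelForImageTextToText` follows the same `_BaseAutoModelClass` pattern as the other RBLN auto classes, dispatching through the transformers image-text-to-text mapping (which is how the new Gemma3 conditional-generation model is reached). A hypothetical usage sketch; the checkpoint name and keyword arguments mirror the existing `RBLNAutoModel*` calling convention and are not documented by this diff.

```python
from optimum.rbln import RBLNAutoModelForImageTextToText

# Compile a vision-language checkpoint for RBLN NPUs on first load (export=True),
# then reload the saved artifacts later without recompiling.
model = RBLNAutoModelForImageTextToText.from_pretrained(
    "google/gemma-3-4b-it",  # illustrative checkpoint
    export=True,
)
model.save_pretrained("gemma3-rbln")
```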
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py CHANGED
@@ -157,7 +157,11 @@ class DecoderOnlyWrapper(nn.Module):
         self.config = causal_lm.config
 
         if use_rotary_emb:
-            self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            if isinstance(rotary_embs, tuple):
+                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+            else:
+                self.rotary_emb = rotary_embs
         else:
             self.rotary_emb = None
 
@@ -195,7 +199,10 @@ class DecoderOnlyWrapper(nn.Module):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    self.use_position_ids,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -203,6 +210,7 @@ class DecoderOnlyWrapper(nn.Module):
                     kvcache_partition_len=self.kvcache_partition_len,
                     kvcache_block_size=self.kvcache_block_size,
                     use_attention_mask=self.use_attention_mask,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -363,6 +371,13 @@ class DecoderOnlyForCausalLM(nn.Module):
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self.lm_head(hidden_states)
+
+        # Apply final logit softmaxing if configured, e.g. for Gemma2
+        if getattr(self.config, "final_logit_softcapping", None) is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
         return logits
 
 
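Final logit soft-capping bounds the logits to `(-cap, cap)` via `cap * tanh(logits / cap)`, which is what Gemma-style configs request through `final_logit_softcapping`. A small numeric illustration (the cap value here is arbitrary):

```python
import torch

cap = 30.0  # e.g. config.final_logit_softcapping
logits = torch.tensor([-100.0, -10.0, 0.0, 10.0, 100.0])

capped = cap * torch.tanh(logits / cap)
print(capped)  # bounded to (-30, 30); small logits pass through nearly unchanged
```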
@@ -610,7 +625,7 @@ class DecoderOnlyAttention(nn.Module):
         self_attn: Original attention module from the base model
     """
 
-    def __init__(self, self_attn, use_attention_mask, kvcache_block_size):
+    def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
         super().__init__()
         self._original_mod = self_attn
         self.layer_idx = self_attn.layer_idx
@@ -629,6 +644,7 @@ class DecoderOnlyAttention(nn.Module):
             self.num_key_value_heads = self.num_heads
 
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
         self.attention = self.get_attention()
         self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
@@ -643,7 +659,9 @@ class DecoderOnlyAttention(nn.Module):
         self.attention.phase = phase
 
     def get_attention(self):
-        return AttentionOp(self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask)
+        return AttentionOp(
+            self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+        )
 
     def __post_init__(self):
         self.q_proj = self._original_mod.q_proj
@@ -716,13 +734,16 @@ class DecoderOnlyAttention(nn.Module):
 
 
 class AttentionOp(nn.Module):
-    def __init__(self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool):
+    def __init__(
+        self, num_heads: int, head_dim: int, num_key_value_heads: int, use_attention_mask: bool, use_position_ids: bool
+    ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
+        self.use_position_ids = use_position_ids
 
     def forward(
         self,
@@ -755,7 +776,8 @@ class AttentionOp(nn.Module):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -772,7 +794,7 @@ class AttentionOp(nn.Module):
         )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -796,11 +818,11 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -824,8 +846,8 @@ class AttentionOp(nn.Module):
                     scale=scale,
                     block_table=block_tables,
                     block_size=block_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,  # FIXME, Hard-coded for Gemma3.
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -927,10 +949,13 @@ class RotaryEmbedding(nn.Module):
 
 
 class DecoderOnlyFlashAttention(DecoderOnlyAttention):
-    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask):
+    def __init__(self, self_attn, kvcache_partition_len, kvcache_block_size, use_attention_mask, use_position_ids):
         self.kvcache_partition_size = kvcache_partition_len
         super().__init__(
-            self_attn=self_attn, use_attention_mask=use_attention_mask, kvcache_block_size=kvcache_block_size
+            self_attn=self_attn,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            kvcache_block_size=kvcache_block_size,
         )
 
     def get_attention(self):
@@ -940,6 +965,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             self.num_key_value_heads,
             self.kvcache_partition_size,
             self.use_attention_mask,
+            self.use_position_ids,
         )
 
     def forward(
@@ -991,12 +1017,14 @@ class FlashAttentionOp(AttentionOp):
         num_key_value_heads: int,
         kvcache_partition_len: int,
         use_attention_mask: bool,
+        use_position_ids: bool,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
@@ -1016,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -1033,7 +1061,7 @@ class FlashAttentionOp(AttentionOp):
         )
 
         if self.phase == "decode":
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     q=query_state,
                     k=key_state,
@@ -1059,10 +1087,10 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    mask=None,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
         else:
-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     q=query_state,
                     k=key_state,
@@ -1088,8 +1116,8 @@ class FlashAttentionOp(AttentionOp):
                     block_table=block_tables,
                     block_size=kvcache_block_size,
                     partition=self.kvcache_partition_size,
-                    is_bidirectional=False,
-                    mask=None,
+                    is_bidirectional=True if self.phase == "image_prefill" else False,
+                    mask=attn_mask if self.use_position_ids else None,
                 )
 
         # reshape for removing repeat_kv
@@ -1098,3 +1126,70 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
         return attn_output
+
+
+class SlidingWindowAttentionOp(AttentionOp):
+    def forward(
+        self,
+        query_state: torch.Tensor,
+        key_state: torch.Tensor,
+        value_state: torch.Tensor,
+        attn_mask: torch.Tensor,
+        past_key_state: torch.Tensor,
+        past_value_state: torch.Tensor,
+        seq_position: Tuple[torch.Tensor],
+        scale: torch.Tensor,
+        block_tables: torch.Tensor,
+        block_size: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+
+        if self.phase == "decode":
+            batch_size = key_state.shape[0]
+        else:
+            batch_size = 1
+
+        query_state = query_state.view(
+            batch_size,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,  # seq len
+            self.head_dim,
+        )
+
+        if self.phase == "decode":
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_decode(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+            )
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_sliding_window_attn_prefill(
+                q=query_state,
+                k=key_state,
+                v=value_state,
+                kcache=past_key_state.unsqueeze(2),
+                vcache=past_value_state.unsqueeze(2),
+                cache_seq_len=seq_position[0],
+                cache_offset=seq_position[1],
+                scale=scale,
+                block_table=block_tables,
+                block_size=block_size,
+                is_bidirectional=True if self.phase == "image_prefill" else False,
+            )
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+        return attn_output
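`SlidingWindowAttentionOp` dispatches to the paged sliding-window custom ops and, like the other ops above, switches to bidirectional attention for the `image_prefill` phase. For intuition only, here is a plain PyTorch sketch of the masking rule such a kernel implements (each query attends causally to at most `window` most recent keys); this is not the RBLN kernel itself:

```python
import torch


def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # mask[i, j] is True when query i may attend to key j:
    # causal (j <= i) and within the sliding window (i - j < window).
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (i - j < window)


print(sliding_window_mask(6, 3).int())
```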