optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from dataclasses import asdict, dataclass
 from typing import Any, Dict, List, Literal, Optional, Union, get_args

-from ....configuration_utils import RBLNModelConfig
+from ....configuration_utils import RBLNModelConfig, RBLNSerializableConfigProtocol
 from ....utils.logging import get_logger
 from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
@@ -59,7 +60,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         phases: Optional[List[PhaseType]] = None,
         logits_to_keep: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
-        **kwargs,
+        kvcache_metas: Optional[List["KVCacheMeta"]] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -93,8 +95,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
                 processing input sequences. Defaults to 128. Must be a positive integer
                 divisible by 64. Affects prefill performance and memory usage.
             kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
-                PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
-                section below for details.
+                PagedAttention KV cache at compile time. Defaults to 0 (automatically determined).
+                See the "KV Cache Number of Blocks (`kvcache_num_blocks`)" section below for details.
             decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
                 This allows the model to handle varying batch sizes efficiently during generation. If not specified,
                 defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
@@ -114,6 +116,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
                 Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
             output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
+            kvcache_metas (Optional[List["KVCacheMeta"]]): The metadata for the KV cache tensors. Handled internally if not provided. Defaults to None.
             kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
@@ -152,17 +155,15 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):


         KV Cache Number of Blocks:
-            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache.
-            Each block holds `kvcache_block_size` tokens of Key and Value states.
-
-            - **Automatic Estimation (Default)**: If `kvcache_num_blocks` is `None`, the system estimates
-              the maximum number of blocks that can fit into the available RBLN device memory. This
-              calculation considers the model size (kernel memory), required buffer memory, the number
-              of layers and heads, `kvcache_block_size`, tensor parallelism, and available RBLN NPU DRAM.
-              This aims to maximize cache capacity for potentially better performance with long sequences
-              or larger batches without manual tuning.
-            - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
-              but requires careful consideration of memory limits. Setting it too high may lead to
+            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache
+            at compile time. Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+            - **Automatic Determination (Default)**: If `kvcache_num_blocks` is `0` (default), the number of blocks
+              is automatically determined during compilation to fit within the available DRAM on the NPU. This allows
+              the model to utilize the remaining memory after compilation without manual tuning, providing optimal
+              cache capacity for better performance with long sequences or larger batches.
+            - **Manual Setting**: You can explicitly set the number of blocks to a positive integer. This provides
+              finer control but requires careful consideration of memory limits. Setting it too high may lead to
               compilation errors if it exceeds available memory. The system will issue warnings if your
               setting exceeds the estimated maximum.
             - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
@@ -175,7 +176,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
              are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).

            The optimal value depends on the specific model, task, hardware, and desired trade-off
-           between performance and memory usage. The automatic estimation provides a robust starting point.
+           between performance and memory usage. Automatic determination (default) provides a robust starting point
+           that adapts to the available DRAM on the NPU at compile time.
         """

         super().__init__(**kwargs)
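
To make the two modes described above concrete, here is a minimal sketch (not taken from the diff); the model ID is a placeholder, and it assumes the config classes accept these keyword arguments when constructed directly, as in the LoRA example later in this diff.

# Minimal sketch, assuming direct construction of the RBLN config (placeholder model ID).
from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder

# Automatic determination (default): kvcache_num_blocks stays 0 and the compiler
# sizes the PagedAttention cache to whatever NPU DRAM remains after compilation.
auto_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192, tensor_parallel_size=4)

# Manual setting: a positive integer pins the block count at compile time and may
# fail to compile if it exceeds available memory.
manual_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192, tensor_parallel_size=4, kvcache_num_blocks=64)

model = RBLNLlamaForCausalLM.from_pretrained(model_id, rbln_config=auto_cfg)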
@@ -222,7 +224,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
             raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")

-        self.kvcache_num_blocks = kvcache_num_blocks
+        self.kvcache_num_blocks = kvcache_num_blocks if kvcache_num_blocks is not None else 0
         self.cache_impl = cache_impl or "static"
         self.sliding_window = sliding_window
         self.sliding_window_layers = sliding_window_layers or []
@@ -257,6 +259,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         # Larger batch size should be at the beginning of the list.
         self.decoder_batch_sizes.sort(reverse=True)

+        self.kvcache_metas: List["KVCacheMeta"] = kvcache_metas or []
+
     @staticmethod
     def validate_phases_type(phases: List[PhaseType]):
         if not isinstance(phases, list):
@@ -284,12 +288,52 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
     def can_generate(self) -> bool:
         return "decode" in self.phases

+    @property
+    def use_image_prefill(self):
+        return "image_prefill" in self.phases
+
+    @property
+    def image_prefill_runtime_idx(self):
+        return self.phases.index("image_prefill")
+
+    @property
+    def expected_compiled_model_names(self):
+        # ["prefill", "image_prefill", "decoder_batch_1", "decoder_batch_2", ...]
+        if self.can_generate:
+            return self.phases[: self.decoder_runtime_idx] + [
+                f"decoder_batch_{batch_size}" for batch_size in self.decoder_batch_sizes
+            ]
+        else:
+            return self.phases
+
+    @property
+    def decoder_runtime_idx(self):
+        if self.can_generate:
+            return self.phases.index("decode")
+        else:
+            raise ValueError("`decode` phase is not in the phases.")
+
     @property
     def nbits_per_param(self) -> int:
         if self.quantization:
             return self.quantization.nbits_per_param
         return 16

+    @property
+    def is_auto_num_blocks(self) -> bool:
+        """Returns True if kvcache_num_blocks will be automatically determined during compilation to fit within the available DRAM on the NPU."""
+        return self.kvcache_num_blocks == 0
+
+    @property
+    def num_full_blocks(self) -> int:
+        return (self.max_seq_len // self.kvcache_block_size) * self.batch_size
+
+    @property
+    def num_min_blocks(self) -> int:
+        if self.attn_impl == "flash_attn":
+            return min(self.max_seq_len // self.kvcache_block_size + 1, self.num_full_blocks)
+        return self.batch_size
+

 class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
     """
@@ -302,3 +346,86 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):

     _default_phases = ["prefill", "decode"]
     _default_logits_to_keep = 1
+
+
+@dataclass
+class KVCacheMeta(RBLNSerializableConfigProtocol):
+    """
+    KVCacheMeta contains metadata describing the key-value (KV) cache tensor for a specific transformer layer.
+
+    This is used during compilation and runtime on RBLN devices to manage memory and configure the
+    static or dynamic characteristics of the cache implementation for decoder-only models.
+
+    Attributes:
+        name (str): Logical name of the KV cache tensor.
+        layer_index (int): Index of the transformer layer corresponding to this cache.
+        shape (list[int]): The 4D shape of the cache tensor:
+            [num_blocks, num_heads, block_size, head_dim]. The number of blocks may be dynamic or static
+            depending on model configuration.
+        layer_type (str): String describing the attention/cache algorithm (e.g., "full_attention", "sliding_attention").
+        is_auto (bool): Whether the number of blocks is automatically determined during compilation (True) or manually specified (False).
+            In both cases, the KV cache size is fixed at compile time.
+        dtype (str): Data type of the cache buffer ("float16", "float32", etc.).
+    """
+
+    name: str
+    layer_index: int
+    shape: list[int]  # (num_blocks, num_heads, block_size(seq), head_dim)
+    layer_type: str
+    is_auto: bool
+    dtype: str
+
+    def _prepare_for_serialization(self) -> dict[str, Any]:
+        return asdict(self)
+
+    @property
+    def compile_shape(self):
+        return [1, self.shape[1], self.shape[2], self.shape[3]] if self.can_resize else self.shape
+
+    @property
+    def can_resize(self):
+        return self.is_auto and self.layer_type == "full_attention"
+
+    @property
+    def num_blocks(self) -> int:
+        return self.shape[0]
+
+    @property
+    def block_size(self) -> int:
+        return self.shape[2]
+
+    @staticmethod
+    def make(
+        name: str,
+        layer_index: int,
+        num_key_value_heads: int,
+        head_dim: int,
+        dtype: str,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> "KVCacheMeta":
+        assert len(rbln_config.compile_cfgs) == 0, "KVCacheMeta cannot be created from rbln_config with compile_cfgs"
+
+        if rbln_config.sliding_window is not None and layer_index in rbln_config.sliding_window_layers:
+            layer_type = "sliding_attention"
+            block_size = rbln_config.sliding_window
+            num_blocks = rbln_config.batch_size
+            is_auto = False
+
+        else:
+            layer_type = "full_attention"
+            block_size = rbln_config.kvcache_block_size
+
+            if rbln_config.is_auto_num_blocks:
+                num_blocks = rbln_config.num_full_blocks
+                is_auto = True
+            else:
+                num_blocks = rbln_config.kvcache_num_blocks
+                is_auto = False
+
+        shape = [num_blocks, num_key_value_heads, block_size, head_dim]
+        if num_blocks <= 0:
+            raise ValueError("`num_blocks` must be greater than 0 when using KV cache.")
+
+        return KVCacheMeta(
+            name=name, layer_index=layer_index, shape=shape, layer_type=layer_type, is_auto=is_auto, dtype=dtype
+        )
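
A hedged illustration of the new helper: the tensor name and dimensions are invented, and it assumes the config class accepts these keyword arguments when constructed directly (a sketch, not a recipe from the package).

cfg = RBLNDecoderOnlyModelForCausalLMConfig(
    batch_size=2, max_seq_len=8192, kvcache_block_size=8192, kvcache_num_blocks=0
)
meta = KVCacheMeta.make(
    name="past_key_values_0_key",  # hypothetical tensor name
    layer_index=0,
    num_key_value_heads=8,
    head_dim=128,
    dtype="float16",
    rbln_config=cfg,
)
# Full-attention layer with the automatic block count:
# meta.shape == [2, 8, 8192, 128] (cfg.num_full_blocks == 2), meta.is_auto is True,
# and meta.compile_shape == [1, 8, 8192, 128] because the block axis stays resizable.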
optimum/rbln/transformers/models/decoderonly/configuration_lora.py

@@ -46,7 +46,7 @@ class RBLNLoRAAdapterConfig(RBLNSerializableConfigProtocol):
         model = RBLNLlamaForCausalLM.from_pretrained(
             model_id,
             rbln_config=RBLNLlamaForCausalLMConfig(lora_config=lora_config, tensor_parallel_size=tp_size, max_seq_len=8192),
-            torch_dtype="auto",
+            dtype="auto",
         )


optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -75,7 +75,7 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )

-        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len, use_rotary_emb)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
         self._phase = "prefill"

@@ -103,7 +103,7 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM

-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int, use_rotary_emb: bool):
         new_layers = []
         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
             is_sliding = layer_idx in self.rbln_config.sliding_window_layers
@@ -118,6 +118,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_layers,
             self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+            use_rotary_emb=use_rotary_emb,
         )

         if self.is_causal_lm:
@@ -144,8 +145,11 @@ class DecoderOnlyWrapper(nn.Module):
         local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
         query_position = (
             args.pop(0)
-            # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
-            if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+            # query_position usage: prefill & (logits_to_keep == 1 or use_local_attention)
+            if (
+                "prefill" in self.phase
+                and (self.rbln_config.logits_to_keep == 1 or self.rbln_config.use_local_attention)
+            )
             else None
         )
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
@@ -240,7 +244,6 @@ class DecoderOnlyForCausalLM(nn.Module):

     Attributes:
         config: Configuration from the original causal language model
-        _original_mod: Reference to the original model for components like lm_head
         model: RBLN-optimized decoder model instance
         _phase: Current processing phase ("prefill" or "decode")
     """
@@ -248,10 +251,9 @@ class DecoderOnlyForCausalLM(nn.Module):
     def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
         super().__init__()
         self.config = causal_lm.config
-        self._original_mod = causal_lm
         self.model = model
         self._phase = "prefill"
-        self.lm_head = self._original_mod.lm_head
+        self.lm_head = causal_lm.lm_head

     @property
     def phase(self):
@@ -293,7 +295,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             output_hidden_states=output_hidden_states,
         )

-        if "prefill" in self.phase:
+        if "prefill" in self.phase and query_position is not None:
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

         logits = self.lm_head(hidden_states)
@@ -317,20 +319,35 @@ class DecoderOnlyModel(nn.Module):
         use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)

     Attributes:
-        _original_mod: Reference to original Huggingface model
         layers: ModuleList of RBLN-optimized transformer layers
         _phase: Current processing phase ("prefill" or "decode")
     """

+    _EMBEDDING_ATTRS = ["embed_tokens", "wte"]
+    _POSITION_ATTRS = ["embed_positions", "wpe"]
+    _LAYERNORM_ATTRS = ["norm", "final_layer_norm", "final_layernorm", "ln_f", "layer_norm"]
+    _PRE_FF_LAYERNORM_ATTRS = None
+    _POST_FF_LAYERNORM_ATTRS = None
+
     def __init__(
         self,
         model,
         layers: List["DecoderOnlyLayer"],
         rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
+        use_rotary_emb=True,
     ):
         super().__init__()
-        self._original_mod = model
+        self.config = model.config
+        # Keep commonly-used original submodules registered on this wrapper so their weights
+        # are preserved in state_dict even if the original model object is not kept.
+        # Different HF model families use different attribute names; we register what we can
+        # and allow subclasses to override getters when needed.
+        self.embed_tokens = _get_attr_from_candidates(model, self._EMBEDDING_ATTRS)
+        # hasattr(model, "rotary_emb") is workaround for Qwen2VL
+        if not (use_rotary_emb or hasattr(model, "rotary_emb")):
+            self.embed_positions = _get_attr_from_candidates(model, self._POSITION_ATTRS)
+        self.norm = _get_attr_from_candidates(model, self._LAYERNORM_ATTRS)
         self.layers = nn.ModuleList(layers)
         self.rbln_config = rbln_config
         self._phase = "prefill"
@@ -369,26 +386,28 @@ class DecoderOnlyModel(nn.Module):
         cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
         return cache_pos_for_partitions

-    def get_local_cache_positions(self, position_ids, query_position):
-        max_cache_len = self._original_mod.config.sliding_window
+    def get_swa_custom_op_args(self, position_ids, query_position):
+        max_cache_len = self.config.sliding_window
         valid_input_len = 1 if query_position is None else query_position + 1
-        cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
+        cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1]  # past seen tokens
         cache_offset = (
             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
         )  # cache offset for next steps

-        return cache_seq_len, cache_offset
+        # Causal mask for sliding window attention
+        attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
+        attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]
+
+        return cache_seq_len, cache_offset, attn_mask

     def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.norm
+        return self.norm

     def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.embed_tokens
+        return self.embed_tokens

     def get_pos_embedding(self) -> nn.Embedding:
-        raise NotImplementedError(
-            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
-        )
+        return self.embed_positions

     def forward(
         self,
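
A toy trace (values invented) of the decode-phase mask that `get_swa_custom_op_args` builds above:

import torch

max_cache_len = 4                    # sliding_window size
cache_seq_len = torch.tensor([[2]])  # two tokens already in the sliding cache
mask = torch.arange(max_cache_len)[None, :] - cache_seq_len   # tensor([[-2, -1,  0,  1]])
mask = torch.where(mask > 0, 0.0, 1.0)[:, None, None, :]      # tensor([[[[1., 1., 1., 0.]]]])
# Slots 0-1 hold cached tokens, slot 2 receives the new token, slot 3 stays masked,
# consistent with cache_offset = cache_seq_len + valid_input_len = 3.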
@@ -464,7 +483,8 @@ class DecoderOnlyModel(nn.Module):

         # Get local cache positions for sliding window layers
         if len(self.sliding_window_layers) > 0:
-            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+            cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+            sliding_cache_pos = (cache_seq_len, cache_offset)

         all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
@@ -472,9 +492,10 @@ class DecoderOnlyModel(nn.Module):
                 all_hidden_states += (hidden_states,)

             is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
             hidden_states = layer(
                 hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                 past_key_values=past_key_values,
                 cos=cos,
@@ -510,14 +531,24 @@ class DecoderOnlyLayer(nn.Module):
         self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN

     Attributes:
-        _original_mod: Reference to original layer for accessing components
         self_attn: Modified attention mechanism mapped to RBLN ops at compile time
         phase: Current operation phase ("prefill" or "decode")
     """

+    _PRE_ATTN_LAYERNORM = ["input_layernorm", "ln_1", "self_attn_layer_norm", "pre_feedforward_layernorm"]
+    _POST_ATTN_LAYERNORM = ["post_attention_layernorm", "ln_2", "final_layer_norm", "post_feedforward_layernorm"]
+    _PRE_FF_LAYERNORM_ATTRS = None
+    _POST_FF_LAYERNORM_ATTRS = None
+    _MLP_ATTR = ("mlp",)
+
     def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
-        self._original_mod = layer
+
+        self.pre_attention_layernorm = _get_attr_from_candidates(layer, self._PRE_ATTN_LAYERNORM)
+        self.post_attention_layernorm = _get_attr_from_candidates(layer, self._POST_ATTN_LAYERNORM)
+        self.pre_feedforward_layernorm = _get_attr_from_candidates(layer, self._PRE_FF_LAYERNORM_ATTRS)
+        self.post_feedforward_layernorm = _get_attr_from_candidates(layer, self._POST_FF_LAYERNORM_ATTRS)
+        self.mlp = _get_attr_from_candidates(layer, self._MLP_ATTR)
         self.self_attn = self_attn
         self._phase = "prefill"
         self.lora_config = lora_config
@@ -547,13 +578,19 @@ class DecoderOnlyLayer(nn.Module):
         self.self_attn.phase = phase

     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.input_layernorm
+        return self.pre_attention_layernorm

     def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.post_attention_layernorm
+        return self.post_attention_layernorm
+
+    def get_pre_feedforward_layernorm(self) -> nn.LayerNorm:
+        return self.pre_feedforward_layernorm
+
+    def get_post_feedforward_layernorm(self) -> nn.LayerNorm:
+        return self.post_feedforward_layernorm

     def get_mlp(self) -> nn.Module:
-        return self._original_mod.mlp
+        return self.mlp

     def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
         mlp = self.get_mlp()
@@ -619,6 +656,8 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding: Whether this is sliding window attention
     """

+    _O_PROJ_ATTRS = ["o_proj", "out_proj", "dense"]
+
     def __init__(
         self,
         self_attn,
@@ -626,20 +665,18 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding=False,
     ):
         super().__init__()
-        self._original_mod = self_attn
+        self.config = getattr(self_attn, "config", None)
         self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
-        self.num_heads = (
-            getattr(self._original_mod, "num_heads", None) or self._original_mod.config.num_attention_heads
-        )
-        self.head_dim = self._original_mod.head_dim
+        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+        self.head_dim = self_attn.head_dim
         self._phase = "prefill"
-        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale(self_attn)))

-        if hasattr(self._original_mod, "num_key_value_heads"):
-            self.num_key_value_heads = self._original_mod.num_key_value_heads
-        elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
-            self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+        if hasattr(self_attn, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.num_key_value_heads
+        elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.config.num_key_value_heads
         else:
             self.num_key_value_heads = self.num_heads

@@ -649,13 +686,16 @@ class DecoderOnlyAttention(nn.Module):
         self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
         self.lora_config = rbln_config.lora_config

+        if hasattr(self_attn, "sinks"):
+            self.sinks = self_attn.sinks.data[:, None]
+
         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.__post_init__()
+        self.__post_init__(self_attn)

     def _init_lora_weights(self):
         """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
         for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
-            original_linear = getattr(self._original_mod, proj_name)
+            original_linear = getattr(self, proj_name)
             lora_linear = LoRALinear(
                 original_linear=original_linear,
                 lora_config=self.lora_config,
@@ -712,16 +752,15 @@ class DecoderOnlyAttention(nn.Module):
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

-    def __post_init__(self):
+    def __post_init__(self, self_attn=None):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = _get_attr_from_candidates(self_attn, self._O_PROJ_ATTRS)
+
         # Initialize LoRA weights if configured, which will replace linear layers
         if self.lora_config:
             self._init_lora_weights()
-        else:
-            # Use original linear layers if no LoRA
-            self.q_proj = self._original_mod.q_proj
-            self.k_proj = self._original_mod.k_proj
-            self.v_proj = self._original_mod.v_proj
-            self.o_proj = self._original_mod.o_proj

     def projection(
         self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
@@ -752,8 +791,8 @@ class DecoderOnlyAttention(nn.Module):
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
         return apply_rotary_pos_emb(query_states, key_states, cos, sin)

-    def get_attn_scale(self):
-        return 1 / math.sqrt(self.head_dim)
+    def get_attn_scale(self, self_attn):
+        return 1 / math.sqrt(self_attn.head_dim)

     def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
@@ -810,6 +849,7 @@ class DecoderOnlyAttention(nn.Module):
                 block_size=self.kvcache_block_size,
                 k_scale=k_scale,
                 v_scale=v_scale,
+                s_aux=getattr(self, "sinks", None),
             )

         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
@@ -882,6 +922,7 @@ class AttentionOp(nn.Module):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.

@@ -898,6 +939,7 @@ class AttentionOp(nn.Module):
             block_size: Block size for paged attention
             k_scale: Scale applied to key
             v_scale: Scale applied to value
+            s_aux: Auxiliary states for attention sinks

         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -953,6 +995,9 @@ class AttentionOp(nn.Module):
             op_args["k_scale"] = k_scale
             op_args["v_scale"] = v_scale

+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1017,6 +1062,7 @@ class FlashAttentionOp(AttentionOp):
         block_size,
         k_scale=None,
         v_scale=None,
+        s_aux=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1070,6 +1116,9 @@ class FlashAttentionOp(AttentionOp):
             op_args["k_scale"] = k_scale
             op_args["v_scale"] = v_scale

+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1122,6 +1171,7 @@ class SlidingWindowAttentionOp(AttentionOp):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert self.quantization is None, "Sliding window attention does not support quantization"
         assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
@@ -1165,6 +1215,11 @@ class SlidingWindowAttentionOp(AttentionOp):
                 op_args["is_bidirectional"] = True
             else:
                 op_args["is_bidirectional"] = False
+        elif self.phase == "decode":
+            op_args["attn_mask"] = attn_mask
+
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux

         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1194,7 +1249,7 @@ class RotaryEmbedding(nn.Module):
         else:
             rope_type = "default"

-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, "cpu", max_seq_len_cached)
         cache_position = torch.arange(0, max_seq_len_cached)
         cache_position_expanded = cache_position[:, None]

@@ -1271,3 +1326,22 @@
     query_states = torch.cat((query_rot, query_pass), dim=-1)
     key_states = torch.cat((key_rot, key_pass), dim=-1)
     return query_states, key_states
+
+
+def _get_attr_from_candidates(
+    src: object,
+    candidates: Optional[List[str]] = None,
+):
+    """
+    Get an attribute from a list of candidate names.
+
+    - If `candidates` is None, this attribute is treated as optional and returns None.
+    - Otherwise, returns `getattr(src, name)` for the first `name` in `candidates` that exists on `src`.
+    - Raises AttributeError if `candidates` is provided but none of the names exist on `src`.
+    """
+    if candidates is None:
+        return None
+    for name in candidates:
+        if hasattr(src, name):
+            return getattr(src, name)
+    raise AttributeError(f"None of the attributes {candidates} exist in {src}")
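
A small usage sketch of the helper above, using a stand-in object rather than a real Hugging Face module (the attribute values are placeholders):

class _FakeGPT2Model:
    def __init__(self):
        self.wte = "token embedding"    # GPT-2 style name
        self.ln_f = "final layer norm"  # GPT-2 style name

m = _FakeGPT2Model()
_get_attr_from_candidates(m, ["embed_tokens", "wte"])   # -> "token embedding"
_get_attr_from_candidates(m, ["norm", "final_layer_norm", "final_layernorm", "ln_f", "layer_norm"])  # -> "final layer norm"
_get_attr_from_candidates(m, None)                       # -> None (attribute treated as optional)
try:
    _get_attr_from_candidates(m, ["mlp"])
except AttributeError:
    pass  # none of the candidate names exist on the object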