optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -12,9 +12,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from dataclasses import asdict, dataclass
  from typing import Any, Dict, List, Literal, Optional, Union, get_args

- from ....configuration_utils import RBLNModelConfig
+ from ....configuration_utils import RBLNModelConfig, RBLNSerializableConfigProtocol
  from ....utils.logging import get_logger
  from ...utils.rbln_quantization import RBLNQuantizationConfig
  from .configuration_lora import RBLNLoRAConfig
@@ -59,6 +60,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
  phases: Optional[List[PhaseType]] = None,
  logits_to_keep: Optional[int] = None,
  output_hidden_states: Optional[bool] = None,
+ kvcache_metas: Optional[List["KVCacheMeta"]] = None,
  **kwargs,
  ):
  """
@@ -93,8 +95,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
  processing input sequences. Defaults to 128. Must be a positive integer
  divisible by 64. Affects prefill performance and memory usage.
  kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
- PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
- section below for details.
+ PagedAttention KV cache at compile time. Defaults to 0 (automatically determined).
+ See the "KV Cache Number of Blocks (`kvcache_num_blocks`)" section below for details.
  decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
  This allows the model to handle varying batch sizes efficiently during generation. If not specified,
  defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
@@ -114,6 +116,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
  logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
  Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
  output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
+ kvcache_metas (Optional[List["KVCacheMeta"]]): The metadata for the KV cache tensors. Handled internally if not provided. Defaults to None.
  kwargs: Additional arguments passed to the parent RBLNModelConfig.

  Raises:
@@ -152,17 +155,15 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):


  KV Cache Number of Blocks:
- `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache.
- Each block holds `kvcache_block_size` tokens of Key and Value states.
-
- - **Automatic Estimation (Default)**: If `kvcache_num_blocks` is `None`, the system estimates
- the maximum number of blocks that can fit into the available RBLN device memory. This
- calculation considers the model size (kernel memory), required buffer memory, the number
- of layers and heads, `kvcache_block_size`, tensor parallelism, and available RBLN NPU DRAM.
- This aims to maximize cache capacity for potentially better performance with long sequences
- or larger batches without manual tuning.
- - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
- but requires careful consideration of memory limits. Setting it too high may lead to
+ `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache
+ at compile time. Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+ - **Automatic Determination (Default)**: If `kvcache_num_blocks` is `0` (default), the number of blocks
+ is automatically determined during compilation to fit within the available DRAM on the NPU. This allows
+ the model to utilize the remaining memory after compilation without manual tuning, providing optimal
+ cache capacity for better performance with long sequences or larger batches.
+ - **Manual Setting**: You can explicitly set the number of blocks to a positive integer. This provides
+ finer control but requires careful consideration of memory limits. Setting it too high may lead to
  compilation errors if it exceeds available memory. The system will issue warnings if your
  setting exceeds the estimated maximum.
  - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
@@ -175,7 +176,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
  are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).

  The optimal value depends on the specific model, task, hardware, and desired trade-off
- between performance and memory usage. The automatic estimation provides a robust starting point.
+ between performance and memory usage. Automatic determination (default) provides a robust starting point
+ that adapts to the available DRAM on the NPU at compile time.
  """

  super().__init__(**kwargs)
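To make the two modes above concrete, here is a minimal sketch of a user-facing configuration. It assumes the public `RBLNLlamaForCausalLMConfig` export (the class name is taken from the LoRA docstring example later in this diff); the sequence length, batch size, and block count are illustrative values, not defaults.

```python
# Sketch only; class name taken from the LoRA docstring example in this diff,
# values are illustrative, not package defaults.
from optimum.rbln import RBLNLlamaForCausalLMConfig

# Automatic determination (default): leave kvcache_num_blocks unset, and the
# compiler sizes the PagedAttention cache to fit the DRAM left on the NPU.
auto_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192, batch_size=1)
print(auto_cfg.kvcache_num_blocks)  # 0 -> decided at compile time

# Manual setting: pin the block count; too large a value can fail compilation.
manual_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192, batch_size=1, kvcache_num_blocks=64)
print(manual_cfg.kvcache_num_blocks)  # 64
```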
@@ -222,7 +224,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
  if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
  raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")

- self.kvcache_num_blocks = kvcache_num_blocks
+ self.kvcache_num_blocks = kvcache_num_blocks if kvcache_num_blocks is not None else 0
  self.cache_impl = cache_impl or "static"
  self.sliding_window = sliding_window
  self.sliding_window_layers = sliding_window_layers or []
@@ -257,6 +259,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
  # Larger batch size should be at the beginning of the list.
  self.decoder_batch_sizes.sort(reverse=True)

+ self.kvcache_metas: List["KVCacheMeta"] = kvcache_metas or []
+
  @staticmethod
  def validate_phases_type(phases: List[PhaseType]):
  if not isinstance(phases, list):
@@ -290,6 +294,21 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
  return self.quantization.nbits_per_param
  return 16

+ @property
+ def is_auto_num_blocks(self) -> bool:
+ """Returns True if kvcache_num_blocks will be automatically determined during compilation to fit within the available DRAM on the NPU."""
+ return self.kvcache_num_blocks == 0
+
+ @property
+ def num_full_blocks(self) -> int:
+ return (self.max_seq_len // self.kvcache_block_size) * self.batch_size
+
+ @property
+ def num_min_blocks(self) -> int:
+ if self.attn_impl == "flash_attn":
+ return min(self.max_seq_len // self.kvcache_block_size + 1, self.num_full_blocks)
+ return self.batch_size
+

  class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
  """
@@ -302,3 +321,86 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):

  _default_phases = ["prefill", "decode"]
  _default_logits_to_keep = 1
+
+
+ @dataclass
+ class KVCacheMeta(RBLNSerializableConfigProtocol):
+ """
+ KVCacheMeta contains metadata describing the key-value (KV) cache tensor for a specific transformer layer.
+
+ This is used during compilation and runtime on RBLN devices to manage memory and configure the
+ static or dynamic characteristics of the cache implementation for decoder-only models.
+
+ Attributes:
+ name (str): Logical name of the KV cache tensor.
+ layer_index (int): Index of the transformer layer corresponding to this cache.
+ shape (list[int]): The 4D shape of the cache tensor:
+ [num_blocks, num_heads, block_size, head_dim]. The number of blocks may be dynamic or static
+ depending on model configuration.
+ layer_type (str): String describing the attention/cache algorithm (e.g., "full_attention", "sliding_attention").
+ is_auto (bool): Whether the number of blocks is automatically determined during compilation (True) or manually specified (False).
+ In both cases, the KV cache size is fixed at compile time.
+ dtype (str): Data type of the cache buffer ("float16", "float32", etc.).
+ """
+
+ name: str
+ layer_index: int
+ shape: list[int] # (num_blocks, num_heads, block_size(seq), head_dim)
+ layer_type: str
+ is_auto: bool
+ dtype: str
+
+ def _prepare_for_serialization(self) -> dict[str, Any]:
+ return asdict(self)
+
+ @property
+ def compile_shape(self):
+ return [1, self.shape[1], self.shape[2], self.shape[3]] if self.can_resize else self.shape
+
+ @property
+ def can_resize(self):
+ return self.is_auto and self.layer_type == "full_attention"
+
+ @property
+ def num_blocks(self) -> int:
+ return self.shape[0]
+
+ @property
+ def block_size(self) -> int:
+ return self.shape[2]
+
+ @staticmethod
+ def make(
+ name: str,
+ layer_index: int,
+ num_key_value_heads: int,
+ head_dim: int,
+ dtype: str,
+ rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+ ) -> "KVCacheMeta":
+ assert len(rbln_config.compile_cfgs) == 0, "KVCacheMeta cannot be created from rbln_config with compile_cfgs"
+
+ if rbln_config.sliding_window is not None and layer_index in rbln_config.sliding_window_layers:
+ layer_type = "sliding_attention"
+ block_size = rbln_config.sliding_window
+ num_blocks = rbln_config.batch_size
+ is_auto = False
+
+ else:
+ layer_type = "full_attention"
+ block_size = rbln_config.kvcache_block_size
+
+ if rbln_config.is_auto_num_blocks:
+ num_blocks = rbln_config.num_full_blocks
+ is_auto = True
+ else:
+ num_blocks = rbln_config.kvcache_num_blocks
+ is_auto = False
+
+ shape = [num_blocks, num_key_value_heads, block_size, head_dim]
+ if num_blocks <= 0:
+ raise ValueError("`num_blocks` must be greater than 0 when using KV cache.")
+
+ return KVCacheMeta(
+ name=name, layer_index=layer_index, shape=shape, layer_type=layer_type, is_auto=is_auto, dtype=dtype
+ )
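A brief standalone sketch of the new `KVCacheMeta`. The import path is an assumption based on where this hunk appears to live (`configuration_decoderonly.py`), and the field values are illustrative.

```python
# Assumed import path; KVCacheMeta itself is introduced in the hunk above.
from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import KVCacheMeta

meta = KVCacheMeta(
    name="past_key_0",            # illustrative tensor name
    layer_index=0,
    shape=[16, 8, 1024, 128],     # [num_blocks, num_kv_heads, block_size, head_dim]
    layer_type="full_attention",
    is_auto=True,                 # block count decided at compile time
    dtype="float16",
)

print(meta.num_blocks, meta.block_size)   # 16 1024
print(meta.can_resize)                    # True: auto + full attention
print(meta.compile_shape)                 # [1, 8, 1024, 128] when resizable
print(meta._prepare_for_serialization())  # plain dict via dataclasses.asdict
```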
@@ -46,7 +46,7 @@ class RBLNLoRAAdapterConfig(RBLNSerializableConfigProtocol):
  model = RBLNLlamaForCausalLM.from_pretrained(
  model_id,
  rbln_config=RBLNLlamaForCausalLMConfig(lora_config=lora_config, tensor_parallel_size=tp_size, max_seq_len=8192),
- torch_dtype="auto",
+ dtype="auto",
  )


@@ -75,7 +75,7 @@ class DecoderOnlyWrapper(nn.Module):
  f" or equal to max_seq_len({rbln_config.max_seq_len})!"
  )

- self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
+ self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len, use_rotary_emb)
  self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
  self._phase = "prefill"

@@ -103,7 +103,7 @@ class DecoderOnlyWrapper(nn.Module):
  def get_rbln_causal_lm_class(self):
  return DecoderOnlyForCausalLM

- def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+ def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int, use_rotary_emb: bool):
  new_layers = []
  for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
  is_sliding = layer_idx in self.rbln_config.sliding_window_layers
@@ -118,6 +118,7 @@ class DecoderOnlyWrapper(nn.Module):
  new_layers,
  self.rbln_config,
  use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+ use_rotary_emb=use_rotary_emb,
  )

  if self.is_causal_lm:
@@ -144,8 +145,11 @@ class DecoderOnlyWrapper(nn.Module):
  local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
  query_position = (
  args.pop(0)
- # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
- if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+ # query_position usage: prefill & (logits_to_keep == 1 or use_local_attention)
+ if (
+ "prefill" in self.phase
+ and (self.rbln_config.logits_to_keep == 1 or self.rbln_config.use_local_attention)
+ )
  else None
  )
  attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
@@ -240,7 +244,6 @@ class DecoderOnlyForCausalLM(nn.Module):

  Attributes:
  config: Configuration from the original causal language model
- _original_mod: Reference to the original model for components like lm_head
  model: RBLN-optimized decoder model instance
  _phase: Current processing phase ("prefill" or "decode")
  """
@@ -248,10 +251,9 @@ class DecoderOnlyForCausalLM(nn.Module):
  def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
  super().__init__()
  self.config = causal_lm.config
- self._original_mod = causal_lm
  self.model = model
  self._phase = "prefill"
- self.lm_head = self._original_mod.lm_head
+ self.lm_head = causal_lm.lm_head

  @property
  def phase(self):
@@ -293,7 +295,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  output_hidden_states=output_hidden_states,
  )

- if "prefill" in self.phase:
+ if "prefill" in self.phase and query_position is not None:
  hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

  logits = self.lm_head(hidden_states)
@@ -317,20 +319,35 @@ class DecoderOnlyModel(nn.Module):
  use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)

  Attributes:
- _original_mod: Reference to original Huggingface model
  layers: ModuleList of RBLN-optimized transformer layers
  _phase: Current processing phase ("prefill" or "decode")
  """

+ _EMBEDDING_ATTRS = ["embed_tokens", "wte"]
+ _POSITION_ATTRS = ["embed_positions", "wpe"]
+ _LAYERNORM_ATTRS = ["norm", "final_layer_norm", "final_layernorm", "ln_f", "layer_norm"]
+ _PRE_FF_LAYERNORM_ATTRS = None
+ _POST_FF_LAYERNORM_ATTRS = None
+
  def __init__(
  self,
  model,
  layers: List["DecoderOnlyLayer"],
  rbln_config: "RBLNDecoderOnlyModelConfig",
  use_learned_pos_emb=None,
+ use_rotary_emb=True,
  ):
  super().__init__()
- self._original_mod = model
+ self.config = model.config
+ # Keep commonly-used original submodules registered on this wrapper so their weights
+ # are preserved in state_dict even if the original model object is not kept.
+ # Different HF model families use different attribute names; we register what we can
+ # and allow subclasses to override getters when needed.
+ self.embed_tokens = _get_attr_from_candidates(model, self._EMBEDDING_ATTRS)
+ # hasattr(model, "rotary_emb") is workaround for Qwen2VL
+ if not (use_rotary_emb or hasattr(model, "rotary_emb")):
+ self.embed_positions = _get_attr_from_candidates(model, self._POSITION_ATTRS)
+ self.norm = _get_attr_from_candidates(model, self._LAYERNORM_ATTRS)
  self.layers = nn.ModuleList(layers)
  self.rbln_config = rbln_config
  self._phase = "prefill"
@@ -369,26 +386,28 @@ class DecoderOnlyModel(nn.Module):
  cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
  return cache_pos_for_partitions

- def get_local_cache_positions(self, position_ids, query_position):
- max_cache_len = self._original_mod.config.sliding_window
+ def get_swa_custom_op_args(self, position_ids, query_position):
+ max_cache_len = self.config.sliding_window
  valid_input_len = 1 if query_position is None else query_position + 1
- cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1] # past seen tokens
+ cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1] # past seen tokens
  cache_offset = (
  torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
  ) # cache offset for next steps

- return cache_seq_len, cache_offset
+ # Causal mask for sliding window attention
+ attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
+ attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]
+
+ return cache_seq_len, cache_offset, attn_mask

  def get_last_layernorm(self) -> nn.LayerNorm:
- return self._original_mod.norm
+ return self.norm

  def get_embedding(self) -> nn.Embedding:
- return self._original_mod.embed_tokens
+ return self.embed_tokens

  def get_pos_embedding(self) -> nn.Embedding:
- raise NotImplementedError(
- "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
- )
+ return self.embed_positions

  def forward(
  self,
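The decode-phase mask that `get_swa_custom_op_args` now builds can be reproduced in isolation. A toy run with a window of 8 and 5 previously seen tokens (values chosen only for illustration):

```python
import torch

max_cache_len = 8                                   # sliding window size (illustrative)
cache_seq_len = torch.tensor([[5]])                 # tokens already in the window, batch of 1

# Same arithmetic as get_swa_custom_op_args in the hunk above:
attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]

print(attn_mask.shape)      # torch.Size([1, 1, 1, 8])
print(attn_mask[0, 0, 0])   # tensor([1., 1., 1., 1., 1., 1., 0., 0.])
```

Window positions up to and including the current token get 1.0, the rest get 0.0, which is what the `torch.where(attn_mask > 0, 0.0, 1.0)` line encodes.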
@@ -464,7 +483,8 @@ class DecoderOnlyModel(nn.Module):

  # Get local cache positions for sliding window layers
  if len(self.sliding_window_layers) > 0:
- sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+ cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+ sliding_cache_pos = (cache_seq_len, cache_offset)

  all_hidden_states = () if output_hidden_states else None
  for layer_idx, layer in enumerate(self.layers):
@@ -472,9 +492,10 @@ class DecoderOnlyModel(nn.Module):
  all_hidden_states += (hidden_states,)

  is_sliding = True if layer_idx in self.sliding_window_layers else False
+ is_sliding_decode = is_sliding and self.phase == "decode"
  hidden_states = layer(
  hidden_states=hidden_states,
- attention_mask=attention_mask,
+ attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
  seq_positions=sliding_cache_pos if is_sliding else seq_positions,
  past_key_values=past_key_values,
  cos=cos,
@@ -510,14 +531,23 @@ class DecoderOnlyLayer(nn.Module):
  self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN

  Attributes:
- _original_mod: Reference to original layer for accessing components
  self_attn: Modified attention mechanism mapped to RBLN ops at compile time
  phase: Current operation phase ("prefill" or "decode")
  """

+ _PRE_ATTN_LAYERNORM = ["input_layernorm", "ln_1", "self_attn_layer_norm", "pre_feedforward_layernorm"]
+ _POST_ATTN_LAYERNORM = ["post_attention_layernorm", "ln_2", "final_layer_norm", "post_feedforward_layernorm"]
+ _PRE_FF_LAYERNORM_ATTRS = None
+ _POST_FF_LAYERNORM_ATTRS = None
+
  def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
  super().__init__()
- self._original_mod = layer
+
+ self.pre_attention_layernorm = _get_attr_from_candidates(layer, self._PRE_ATTN_LAYERNORM)
+ self.post_attention_layernorm = _get_attr_from_candidates(layer, self._POST_ATTN_LAYERNORM)
+ self.pre_feedforward_layernorm = _get_attr_from_candidates(layer, self._PRE_FF_LAYERNORM_ATTRS)
+ self.post_feedforward_layernorm = _get_attr_from_candidates(layer, self._POST_FF_LAYERNORM_ATTRS)
+ self.mlp = layer.mlp
  self.self_attn = self_attn
  self._phase = "prefill"
  self.lora_config = lora_config
@@ -547,13 +577,19 @@ class DecoderOnlyLayer(nn.Module):
  self.self_attn.phase = phase

  def get_pre_attention_layernorm(self) -> nn.LayerNorm:
- return self._original_mod.input_layernorm
+ return self.pre_attention_layernorm

  def get_post_attention_layernorm(self) -> nn.LayerNorm:
- return self._original_mod.post_attention_layernorm
+ return self.post_attention_layernorm
+
+ def get_pre_feedforward_layernorm(self) -> nn.LayerNorm:
+ return self.pre_feedforward_layernorm
+
+ def get_post_feedforward_layernorm(self) -> nn.LayerNorm:
+ return self.post_feedforward_layernorm

  def get_mlp(self) -> nn.Module:
- return self._original_mod.mlp
+ return self.mlp

  def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
  mlp = self.get_mlp()
@@ -619,6 +655,8 @@ class DecoderOnlyAttention(nn.Module):
  is_sliding: Whether this is sliding window attention
  """

+ _O_PROJ_ATTRS = ["o_proj", "out_proj", "dense"]
+
  def __init__(
  self,
  self_attn,
@@ -626,20 +664,18 @@ class DecoderOnlyAttention(nn.Module):
  is_sliding=False,
  ):
  super().__init__()
- self._original_mod = self_attn
+ self.config = getattr(self_attn, "config", None)
  self.rbln_config = rbln_config
  self.layer_idx = self_attn.layer_idx
- self.num_heads = (
- getattr(self._original_mod, "num_heads", None) or self._original_mod.config.num_attention_heads
- )
- self.head_dim = self._original_mod.head_dim
+ self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+ self.head_dim = self_attn.head_dim
  self._phase = "prefill"
- self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+ self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale(self_attn)))

- if hasattr(self._original_mod, "num_key_value_heads"):
- self.num_key_value_heads = self._original_mod.num_key_value_heads
- elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
- self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+ if hasattr(self_attn, "num_key_value_heads"):
+ self.num_key_value_heads = self_attn.num_key_value_heads
+ elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+ self.num_key_value_heads = self_attn.config.num_key_value_heads
  else:
  self.num_key_value_heads = self.num_heads

@@ -649,13 +685,16 @@ class DecoderOnlyAttention(nn.Module):
  self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
  self.lora_config = rbln_config.lora_config

+ if hasattr(self_attn, "sinks"):
+ self.sinks = self_attn.sinks.data[:, None]
+
  setattr(self, self.get_attention_name(), self.create_attention_op())
- self.__post_init__()
+ self.__post_init__(self_attn)

  def _init_lora_weights(self):
  """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
  for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
- original_linear = getattr(self._original_mod, proj_name)
+ original_linear = getattr(self, proj_name)
  lora_linear = LoRALinear(
  original_linear=original_linear,
  lora_config=self.lora_config,
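For context on the new `sinks` handling above: the `[:, None]` indexing turns a 1-D per-head sink parameter (as used by GPT-OSS-style attention) into a column so it can broadcast against per-head scores when passed as `s_aux`. A minimal torch illustration with a made-up head count:

```python
import torch

num_heads = 4                                        # made-up value
sinks = torch.nn.Parameter(torch.zeros(num_heads))   # 1-D per-head sink logits

s_aux = sinks.data[:, None]                          # same reshape as in the hunk above
print(sinks.shape, s_aux.shape)                      # torch.Size([4]) torch.Size([4, 1])
```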
@@ -712,16 +751,15 @@ class DecoderOnlyAttention(nn.Module):
  else:
  raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

- def __post_init__(self):
+ def __post_init__(self, self_attn=None):
+ self.q_proj = self_attn.q_proj
+ self.k_proj = self_attn.k_proj
+ self.v_proj = self_attn.v_proj
+ self.o_proj = _get_attr_from_candidates(self_attn, self._O_PROJ_ATTRS)
+
  # Initialize LoRA weights if configured, which will replace linear layers
  if self.lora_config:
  self._init_lora_weights()
- else:
- # Use original linear layers if no LoRA
- self.q_proj = self._original_mod.q_proj
- self.k_proj = self._original_mod.k_proj
- self.v_proj = self._original_mod.v_proj
- self.o_proj = self._original_mod.o_proj

  def projection(
  self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
@@ -752,8 +790,8 @@ class DecoderOnlyAttention(nn.Module):
  def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
  return apply_rotary_pos_emb(query_states, key_states, cos, sin)

- def get_attn_scale(self):
- return 1 / math.sqrt(self.head_dim)
+ def get_attn_scale(self, self_attn):
+ return 1 / math.sqrt(self_attn.head_dim)

  def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
  if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
@@ -810,6 +848,7 @@ class DecoderOnlyAttention(nn.Module):
  block_size=self.kvcache_block_size,
  k_scale=k_scale,
  v_scale=v_scale,
+ s_aux=getattr(self, "sinks", None),
  )

  # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
@@ -882,6 +921,7 @@ class AttentionOp(nn.Module):
  block_size: int,
  k_scale: Optional[torch.Tensor] = None,
  v_scale: Optional[torch.Tensor] = None,
+ s_aux: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Compute attention with static shapes and explicit cache management.

@@ -898,6 +938,7 @@ class AttentionOp(nn.Module):
  block_size: Block size for paged attention
  k_scale: Scale applied to key
  v_scale: Scale applied to value
+ s_aux: Auxiliary states for attention sinks

  Returns:
  Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -953,6 +994,9 @@ class AttentionOp(nn.Module):
  op_args["k_scale"] = k_scale
  op_args["v_scale"] = v_scale

+ if s_aux is not None:
+ op_args["s_aux"] = s_aux
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -1017,6 +1061,7 @@ class FlashAttentionOp(AttentionOp):
  block_size,
  k_scale=None,
  v_scale=None,
+ s_aux=None,
  ):
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2)
@@ -1070,6 +1115,9 @@ class FlashAttentionOp(AttentionOp):
  op_args["k_scale"] = k_scale
  op_args["v_scale"] = v_scale

+ if s_aux is not None:
+ op_args["s_aux"] = s_aux
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -1122,6 +1170,7 @@ class SlidingWindowAttentionOp(AttentionOp):
  block_size: int,
  k_scale: Optional[torch.Tensor] = None,
  v_scale: Optional[torch.Tensor] = None,
+ s_aux: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  assert self.quantization is None, "Sliding window attention does not support quantization"
  assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
@@ -1165,6 +1214,11 @@ class SlidingWindowAttentionOp(AttentionOp):
  op_args["is_bidirectional"] = True
  else:
  op_args["is_bidirectional"] = False
+ elif self.phase == "decode":
+ op_args["attn_mask"] = attn_mask
+
+ if s_aux is not None:
+ op_args["s_aux"] = s_aux

  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1194,7 +1248,7 @@ class RotaryEmbedding(nn.Module):
  else:
  rope_type = "default"

- inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, "cpu", max_seq_len_cached)
  cache_position = torch.arange(0, max_seq_len_cached)
  cache_position_expanded = cache_position[:, None]

@@ -1271,3 +1325,22 @@ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tu
  query_states = torch.cat((query_rot, query_pass), dim=-1)
  key_states = torch.cat((key_rot, key_pass), dim=-1)
  return query_states, key_states
+
+
+ def _get_attr_from_candidates(
+ src: object,
+ candidates: Optional[List[str]] = None,
+ ):
+ """
+ Get an attribute from a list of candidate names.
+
+ - If `candidates` is None, this attribute is treated as optional and returns None.
+ - Otherwise, returns `getattr(src, name)` for the first `name` in `candidates` that exists on `src`.
+ - Raises AttributeError if `candidates` is provided but none of the names exist on `src`.
+ """
+ if candidates is None:
+ return None
+ for name in candidates:
+ if hasattr(src, name):
+ return getattr(src, name)
+ raise AttributeError(f"None of the attributes {candidates} exist in {src}")
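A self-contained sketch of the candidate-lookup pattern this helper enables across differently named HF modules; the helper body is copied from the hunk above, and the toy modules are hypothetical stand-ins for real model families.

```python
from typing import List, Optional
import torch.nn as nn


def _get_attr_from_candidates(src: object, candidates: Optional[List[str]] = None):
    # Copied from the hunk above.
    if candidates is None:
        return None
    for name in candidates:
        if hasattr(src, name):
            return getattr(src, name)
    raise AttributeError(f"None of the attributes {candidates} exist in {src}")


class LlamaStyle(nn.Module):      # hypothetical: models exposing embed_tokens
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 4)


class Gpt2Style(nn.Module):       # hypothetical: models exposing wte instead
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(10, 4)


_EMBEDDING_ATTRS = ["embed_tokens", "wte"]
print(_get_attr_from_candidates(LlamaStyle(), _EMBEDDING_ATTRS))  # resolves embed_tokens
print(_get_attr_from_candidates(Gpt2Style(), _EMBEDDING_ATTRS))   # falls back to wte
print(_get_attr_from_candidates(Gpt2Style(), None))               # optional attribute -> None
```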
@@ -177,7 +177,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  dec_attn_mask: torch.Tensor,
  page_table_manager: RBLNPageTableManager,
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
- config: "PreTrainedConfig" = None,
+ config: Optional["PreTrainedConfig"] = None,
  logits_last_dim: Optional[int] = None,
  **kwargs: Any,
  ) -> None:
@@ -391,16 +391,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  # Initialize attention mask for chunked processing
  if self.rbln_config.use_attention_mask:
  if self.rbln_config.use_position_ids:
- chunked_attention_mask = torch.zeros(
- 1, self.rbln_config.max_seq_len, dtype=self.rbln_config.torch_dtype
- )
+ chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=self.rbln_config.dtype)
  else:
  chunked_attention_mask = torch.zeros(
  1,
  1,
  self.rbln_config.prefill_chunk_size,
  self.rbln_config.max_seq_len,
- dtype=self.rbln_config.torch_dtype,
+ dtype=self.rbln_config.dtype,
  )
  else:
  chunked_attention_mask = None
@@ -467,7 +465,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
  logits_last_dim,
  )
- output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+ output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.dtype)

  if self.rbln_config.logits_to_keep == 1:
  for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
@@ -486,7 +484,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  self.config.hidden_size,
  )
  output_hidden_states = [
- torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+ torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
  for _ in range(self.config.num_hidden_layers + 1)
  ]