optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -36,6 +36,7 @@ from .decoderonly_architecture import (
     DecoderOnlyWrapper,
     set_default_values,
     validate_attention_method,
+    validate_sliding_window_size,
 )


@@ -56,39 +57,28 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         block_tables: torch.Tensor,
         free_block_pool: Deque,
-        kvcache_block_size: int,
-        use_attention_mask: bool,
-        attn_impl: str,
-        use_position_ids: bool,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
-
-        # shared data structures between prefill and decode phase
-        self.use_attention_mask = use_attention_mask
+        self.rbln_config = rbln_config

         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
-        self.use_position_ids = use_position_ids

-        self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
-        self.attn_impl = attn_impl
-
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
-            self.max_seq_len = kwargs.pop("max_seq_len")
-            self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
             self.output_size = [1, 1, vocab_size]
             self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )

-    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
         """
         Manages and returns the KV cache block tables.
         Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
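The `causal_mask` built in the constructor above is the standard lower-triangular mask for one prefill chunk: `1 - triu(ones, diagonal=1)` keeps the diagonal and everything below it, so position i may attend to positions <= i. A minimal standalone sketch (plain PyTorch, chunk size 4 chosen only for illustration):

```python
import torch

# Standalone illustration, not package code: the prefill causal mask for a chunk of 4.
prefill_chunk_size = 4
causal_mask = 1 - torch.triu(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1)
print(causal_mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```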
@@ -98,7 +88,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.

         Returns:
-            torch.Tensor: Updated block tables.
+            Updated block tables.
         """

         NO_BLOCKS_ERROR = (
@@ -131,31 +121,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             else:
                 raise RuntimeError(NO_BLOCKS_ERROR)

-        if self.phase == "prefill":
-            # Track previously used blocks and return them to the free_block_pool and
-            # reset the current batch's block table to empty blocks
-            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-            self.free_block_pool.extend(prev_blocks)
-            self.block_tables[batch_idx].fill_(self.empty_block)
-
-            # Get the start (s) and end (e) positions from cache_position and
-            # iterate over the cache positions to allocate necessary blocks
-            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-            for position in range(s, e + 1, self.kvcache_block_size):
-                block_idx = position // self.kvcache_block_size
-                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                update_block(batch_idx, block_idx)
-
-            return replace_empty_block(self.block_tables[batch_idx])
-        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-        else:
-            for b_idx in range(self.batch_size):
-                position = cache_position[b_idx][0].item()
-                block_idx = position // self.kvcache_block_size
-                update_block(b_idx, block_idx)
+        def get_global_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "sliding_window":
+                return None
+
+            if self.phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.empty_block)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                    update_block(batch_idx, block_idx)
+
+                return replace_empty_block(self.block_tables[batch_idx])
+            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(self.batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    update_block(b_idx, block_idx)

-            return replace_empty_block(self.block_tables)
+                return replace_empty_block(self.block_tables)
+
+        def get_local_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "static":
+                return None
+            else:
+                return (
+                    torch.tensor([batch_idx], dtype=torch.int16)
+                    if self.phase == "prefill"
+                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
+                )
+
+        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
+
+    def is_external_block_tables(
+        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+    ):
+        if self.rbln_config.cache_impl == "static" and block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "hybrid":
+            if (block_tables is not None) != (local_block_tables is not None):
+                raise ValueError(
+                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
+                )
+            elif block_tables is None and local_block_tables is None:
+                return False
+
+        return True

     def forward(
         self,
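The refactored `get_global_block_tables` keeps the existing paged-KV-cache bookkeeping: each sequence owns a row of logical block slots, and physical blocks are drawn from a shared free pool as prefill touches new cache positions. A toy re-enactment of that allocation loop, with made-up sizes (max_seq_len 1024, kvcache_block_size 256, 8 physical blocks), not the package's API:

```python
from collections import deque

import torch

EMPTY = -1
block_table = torch.full((4,), EMPTY, dtype=torch.int16)  # 1024 // 256 = 4 slots
free_block_pool = deque(range(8))  # 8 physical blocks shared by all sequences

# Prefill of 600 tokens touches cache positions 0..599, i.e. logical blocks 0..2.
for position in range(0, 600, 256):
    block_idx = position // 256
    if block_table[block_idx] == EMPTY:
        block_table[block_idx] = free_block_pool.popleft()

print(block_table.tolist())  # [0, 1, 2, -1]: three physical blocks mapped
```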
@@ -180,11 +203,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds

-        if block_tables is None:
-            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-            is_external_block_tables = False
-        else:
-            is_external_block_tables = True
+        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+        if not is_external_block_tables:
+            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)

         if self.phase == "decode":
             return self.decode_forward(
@@ -204,6 +225,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask,
                 batch_idx,
                 block_tables,
+                is_external_block_tables=is_external_block_tables,
                 position_embed=position_embed,
                 token_type_ids=token_type_ids,
                 local_block_tables=local_block_tables,
@@ -229,7 +251,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -245,7 +267,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

             attention_mask = self.dec_attn_mask

-        if self.batch_size < block_tables.shape[0]:
+        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
             block_tables = block_tables[: self.batch_size]

         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -255,9 +277,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
             block_tables,
+            local_block_tables,
             position_embed,
-            attention_mask if self.use_attention_mask else None,
-            position_ids if self.use_position_ids else None,
+            attention_mask if self.rbln_config.use_attention_mask else None,
+            position_ids if self.rbln_config.use_position_ids else None,
         )

         return RBLNDecoderOnlyOutput(logits=logits)
@@ -268,7 +291,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
@@ -283,15 +305,15 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )

         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )

         # Initialize attention mask for chunked processing
         chunked_attention_mask = (
-            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
-            if self.use_attention_mask
+            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+            if self.rbln_config.use_attention_mask
             else None
         )

@@ -305,8 +327,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         ]

         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
             # inputs_embeds
             if inputs.dim() == 3:
                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
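The new `padding_size = 0` makes the no-padding path explicit, presumably so the variable is defined even when `query_length` already divides evenly. A quick check of the modular arithmetic, with an assumed chunk size of 128:

```python
# Illustration only: a 300-token prompt is padded up to the next multiple of
# prefill_chunk_size before chunked prefill.
prefill_chunk_size = 128
query_length = 300
padding_size = 0
if query_length % prefill_chunk_size != 0:
    padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size
print(padding_size)                  # 84
print(query_length + padding_size)   # 384, i.e. exactly 3 chunks of 128
```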
@@ -351,10 +374,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
+        is_external_block_tables: bool = False,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -375,39 +398,47 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )

         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
             if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
+                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]

-            if self.use_attention_mask and not self.use_position_ids:
+            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
-                if step >= self.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if step >= self.rbln_config.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask

             # Define query position
-            query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+            if step + self.rbln_config.prefill_chunk_size >= query_length:
+                query_position = torch.tensor(
+                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
+                )
+            else:
+                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)

             # Forward pass for the current chunk
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                local_block_tables,
                 position_embed_chunk if position_embed is not None else None,
                 query_position,
-                chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if self.use_position_ids else None,
+                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                position_ids_chunk if self.rbln_config.use_position_ids else None,
                 out=out_buffers,
             )

         # Update decoder attention mask with processed KV-cache length from prefill phase
-        if not is_external_block_tables and self.use_attention_mask:
+        if not is_external_block_tables and self.rbln_config.use_attention_mask:
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

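The reworked `query_position` now distinguishes final from non-final chunks: intermediate chunks point at their last slot, while the final chunk points at the last real token and ignores the padding. A worked example with the same assumed numbers as above (chunk size 128, 300-token prompt padded to 384):

```python
# Illustration only: per-chunk query_position over a chunked prefill.
prefill_chunk_size, query_length = 128, 300
for step in range(0, query_length, prefill_chunk_size):
    if step + prefill_chunk_size >= query_length:
        query_position = (query_length - 1) % prefill_chunk_size  # last real token
    else:
        query_position = prefill_chunk_size - 1  # last slot of a full chunk
    print(step, query_position)
# 0 127
# 128 127
# 256 43
```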
@@ -427,6 +458,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.

     The class provides core functionality for:
+
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
@@ -477,13 +509,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            rbln_config=self.rbln_config,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-            max_seq_len=self.rbln_config.max_seq_len,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
         )

         self.decoders = {}
@@ -497,10 +524,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-                kvcache_block_size=self.rbln_config.kvcache_block_size,
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -509,15 +533,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def save_torch_artifacts(
         cls,
-        model: "PreTrainedModel",
+        model: PreTrainedModel,
         save_dir_path: Path,
         subfolder: str,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
         if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
@@ -545,7 +567,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
-        config: Optional["PretrainedConfig"] = None,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -584,16 +606,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return model

     def __getattr__(self, __name: str) -> Any:
-        """
-        Special method to delegate attribute access to the original Huggingface LM class.
-        This method is called when an attribute is not found in the current instance's dictionary.
-        It enables transparent access to the original model's attributes and methods while maintaining
-        proper method binding.
-
-        The method implements a delegation pattern that:
-        1. For methods: Creates a wrapper that properly binds 'self' to method calls
-        2. For other attributes: Returns them directly from the original class
-        """
+        # Special method to delegate attribute access to the original Huggingface LM class.
+        # This method is called when an attribute is not found in the current instance's dictionary.
+        # It enables transparent access to the original model's attributes and methods while maintaining
+        # proper method binding.
+
+        # The method implements a delegation pattern that:
+
+        # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        # 2. For other attributes: Returns them directly from the original class

         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
@@ -606,7 +627,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def get_pytorch_model(
         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
-    ) -> "PreTrainedModel":
+    ) -> PreTrainedModel:
         if rbln_config and rbln_config.quantization:
             model = cls.get_quantized_model(*args, **kwargs)
         else:
@@ -615,7 +636,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return model

     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
         wrapper_cfg = {
             "max_seq_len": rbln_config.max_seq_len,
             "attn_impl": rbln_config.attn_impl,
@@ -625,12 +646,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "use_attention_mask": rbln_config.use_attention_mask,
             "use_position_ids": rbln_config.use_position_ids,
             "use_inputs_embeds": rbln_config.use_inputs_embeds,
+            "cache_impl": rbln_config.cache_impl,
+            "sliding_window": rbln_config.sliding_window,
+            "sliding_window_layers": rbln_config.sliding_window_layers,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
@@ -655,9 +679,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             quantization.maybe_set_quantization_env()
             original_linear = torch.nn.functional.linear
             torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-            compiled_model = RBLNModel.compile(
+            compiled_model = cls.compile(
                 wrapped_model,
                 compile_config,
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device,
                 example_inputs=example_inputs,
                 compile_context=compile_context,
             )
@@ -709,7 +735,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
         ):
             alloc_memory_by_key[key] += sum(memory_per_node)
-        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
+        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight

         # Get the maximum number of blocks that can be allocated
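Switching to `dict.pop(key, default)` lets the same code accept whichever allocation key the installed compiler emits, instead of raising `KeyError` for the one that is absent. A sketch of the behavior:

```python
# Illustration only: pop with a default tolerates a missing key.
alloc = {"DramTensor": 1024, "Kernel": 2048}
alloc.pop("PortRecur", None)   # absent on newer compilers: no error
alloc.pop("DramTensor", None)  # absent on older compilers: no error
print(alloc)  # {'Kernel': 2048}
```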
@@ -746,35 +773,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         buffer: Optional[int] = None,
         num_runtimes: int = 2,
     ) -> int:
-        """
-        We are finding max_n_blocks(x) that satisfies the following equation:
-
-        available_dram - kernel_size - buffer
-        - num_layers * 2 * tensor_parallel_size
-        * align_2MB(
-            x
-            * block_size
-            * align_64(head_dim)
-            * math.ceil(num_key_value_heads / tensor_parallel_size)
-            * 2
-        ) > 0
-
-        This inequality can be rewritten as follows:
-
-        a - c * align_2MB(b * x) > 0
-        where
-            a = available_dram - kernel_size - buffer
-            b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-            c = num_layers * 2 * tensor_parallel_size
-
-        We can rewrite the inequality as follows:
-        k > align_2MB(b*x)
-        where
-            k = a / c
-
-        After that, we can derive the following equation:
-        x = floor(2**21 / b * floor((k - 1) / 2**21))
-        """
+        # We are finding max_n_blocks(x) that satisfies the following equation:
+
+        # available_dram - kernel_size - buffer
+        # - num_layers * 2 * tensor_parallel_size
+        # * align_2MB(
+        #     x
+        #     * block_size
+        #     * align_64(head_dim)
+        #     * math.ceil(num_key_value_heads / tensor_parallel_size)
+        #     * 2
+        # ) > 0
+
+        # This inequality can be rewritten as follows:
+
+        # a - c * align_2MB(b * x) > 0
+        # where
+        #     a = available_dram - kernel_size - buffer
+        #     b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+        #     c = num_layers * 2 * tensor_parallel_size
+
+        # We can rewrite the inequality as follows:
+        # k > align_2MB(b*x)
+        # where
+        #     k = a / c
+
+        # After that, we can derive the following equation:
+        # x = floor(2**21 / b * floor((k - 1) / 2**21))

         def align(x: int, nbytes: int) -> int:
             return int(math.ceil(x / nbytes) * nbytes)
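The derivation above solves for the largest block count whose 2 MiB-aligned KV allocation still fits in the remaining DRAM. A self-contained sketch of the closed form, with placeholder numbers that are not RBLN device specifications:

```python
import math

def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)

def max_n_blocks(available_dram, kernel_size, buffer,
                 num_layers, tensor_parallel_size,
                 block_size, head_dim, num_key_value_heads) -> int:
    a = available_dram - kernel_size - buffer
    b = block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
    c = num_layers * 2 * tensor_parallel_size
    k = a / c
    # x = floor(2**21 / b * floor((k - 1) / 2**21)): round the per-layer budget
    # down to a whole number of 2 MiB pages before dividing by the block stride b.
    return int(2**21 / b * math.floor((k - 1) / 2**21))

print(max_n_blocks(available_dram=16 * 2**30, kernel_size=4 * 2**30,
                   buffer=2**30, num_layers=32, tensor_parallel_size=4,
                   block_size=128, head_dim=128, num_key_value_heads=8))  # 672
```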
@@ -833,22 +858,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cls,
         batch_size: int,
         query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        model_config: PretrainedConfig,
     ):
-        if use_inputs_embeds:
+        is_prefill: bool = query_length > 1
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
+        local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+
+        # 1. main input
+        if rbln_config.use_inputs_embeds:
             main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
         else:
             main_input = ("input_ids", [batch_size, query_length], "int64")

+        # 2. cache_position
         input_info = [
             main_input,
             (
@@ -858,38 +885,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]

-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-        if query_length > 1:
+        # 3. block_tables
+        if rbln_config.cache_impl in ["static", "hybrid"]:
+            max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
             input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
+                [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
             )
-        if use_attention_mask:
+        if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+            input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+
+        # 4. query_position
+        if is_prefill:
+            input_info.extend([("query_position", [], "int16")])
+
+        # 5. attention_mask & position_ids
+        if rbln_config.use_attention_mask:
             input_info.extend(
                 [
-                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                    ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
+                    if rbln_config.use_position_ids
+                    else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
                 ]
             )
-        if use_position_ids:
+        if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))

+        # 6. past_key_values
+        global_kvcache_shape = [
+            rbln_config.kvcache_num_blocks,
+            num_key_value_heads,
+            rbln_config.kvcache_block_size,
+            head_dim,
+        ]
+        local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
         input_info.extend(
             [
                 (
                     f"past_key_values_{i}",
-                    [
-                        kvcache_num_blocks,
-                        num_key_value_heads,
-                        kvcache_block_size,
-                        head_dim,
-                    ],
+                    local_kvcache_shape
+                    if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
+                    else global_kvcache_shape,
                     "float32",
                 )
                 for i in range(num_hidden_layers * 2)
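Since `past_key_values_{i}` enumerates one K and one V tensor per layer, `i // 2` recovers the layer index; sliding-window layers get a fixed-width local cache while the remaining layers share the paged global cache. An illustration of the resulting per-layer shapes, with invented dimensions:

```python
# Illustration only (made-up numbers): cache shapes for a hypothetical hybrid
# model where even-numbered layers use sliding-window attention.
kvcache_num_blocks, kvcache_block_size = 512, 128  # global, paged
local_blocks, sliding_window = 4, 256              # one local block per decoder batch entry
num_key_value_heads, head_dim = 8, 128
sliding_window_layers = [0, 2, 4]

for layer in range(6):
    if layer in sliding_window_layers:
        shape = [local_blocks, num_key_value_heads, sliding_window, head_dim]
    else:
        shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
    print(f"layer {layer}: K/V cache {shape}")
```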
@@ -898,12 +933,50 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ]
         return input_info

+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # Update the sliding window configuration for the RBLN model.
+
+        # This method must be implemented by subclasses to handle their specific sliding window configurations,
+        # as Hugging Face models use different configuration keys to represent sliding window layers.
+
+        # Args:
+        #     model_config (PretrainedConfig): The model configuration from Hugging Face.
+        #     rbln_config (RBLNDecoderOnlyModelForCausalLMConfig): The RBLN model configuration.
+
+        # Notes:
+        #     Required configuration settings:
+        #     - `cache_impl`: Must be one of:
+        #         - "static": All layers use global attention (no sliding window)
+        #         - "sliding_window": All layers use sliding window attention
+        #         - "hybrid": A mix of global and sliding window attention layers
+        #     - `sliding_window`: Width of the sliding window (required if cache_impl is "sliding_window" or "hybrid")
+        #     - `sliding_window_layers`: List of layer indices using sliding window attention (required if cache_impl is "hybrid")
+
+        # Example implementation for a 'sliding_window' model:
+        #     ```python
+        #     rbln_config.cache_impl = "sliding_window"
+        #     rbln_config.sliding_window = model_config.sliding_window
+        #     rbln_config.sliding_window_layers = [i for i in range(model_config.num_hidden_layers)]
+        #     return rbln_config
+        #     ```
+
+        # Returns:
+        #     RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+
+        raise NotImplementedError(
+            "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
+            "See method docstring for required configuration details."
+        )
+
     @classmethod
     def _update_rbln_config(
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
+        model: Optional[PreTrainedModel] = None,
+        model_config: Optional[PretrainedConfig] = None,
         rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
     ) -> RBLNDecoderOnlyModelForCausalLMConfig:
         if rbln_config.max_seq_len is None:
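To make the contract of the new hook concrete, here is a hypothetical subclass override for a hybrid model, mirroring the `sliding_window` example embedded in the comments above. It is not taken from the package: the `layer_types` attribute on `model_config` and the subclass name are assumptions for illustration.

```python
from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import (
    RBLNDecoderOnlyModelForCausalLM,
)

class RBLNMyHybridModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):  # hypothetical
    @classmethod
    def _update_sliding_window_config(cls, model_config, rbln_config):
        # Hybrid layout: only the layers flagged as sliding-window in the
        # (assumed) per-layer `layer_types` list get the local cache.
        rbln_config.cache_impl = "hybrid"
        rbln_config.sliding_window = model_config.sliding_window
        rbln_config.sliding_window_layers = [
            i for i, kind in enumerate(model_config.layer_types) if kind == "sliding_attention"
        ]
        return rbln_config
```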
@@ -913,6 +986,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")

+        if getattr(model_config, "sliding_window", None) is not None and getattr(
+            model_config, "use_sliding_window", True
+        ):
+            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+            if rbln_config.sliding_window is not None:
+                validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -961,25 +1041,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "This can cause a failure during model compilation."
             )
         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads

         prefill_input_info = cls.get_input_info(
             batch_size=1,
             query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
+            rbln_config=rbln_config,
+            model_config=model_config,
         )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
@@ -989,16 +1056,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         dec_input_info = cls.get_input_info(
             batch_size=batch_size,
             query_length=1,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
+            rbln_config=rbln_config,
+            model_config=model_config,
         )
         dec_compile_configs.append(
             RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
@@ -1122,12 +1181,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        """
-        Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
-        For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
-        A for-loop ensures synchronization with the HuggingFace generate API.
-        The decoder stage operates as usual, processing inputs in batch mode.
-        """
+        # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+        # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+        # A for-loop ensures synchronization with the HuggingFace generate API.
+        # The decoder stage operates as usual, processing inputs in batch mode.
+
         # Prefill
         if cache_position is None:
             logits = []