optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +1 -1
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
  53. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  54. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  55. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  56. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  57. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +231 -175
  59. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  60. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  63. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  64. optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
  65. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  66. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  67. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  68. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  69. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  70. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +87 -236
  71. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  72. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  73. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  74. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  75. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  76. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  77. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  78. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
  79. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  80. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  81. optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
  82. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  83. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  84. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  85. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  86. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  87. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  88. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  89. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  90. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  91. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  92. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  93. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +46 -25
  94. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
  95. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  96. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  97. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  98. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  99. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  100. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  102. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  104. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  105. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  106. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  107. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  108. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  109. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  110. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  111. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  112. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  114. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  115. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  116. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  117. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  118. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  119. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  120. optimum/rbln/utils/model_utils.py +20 -0
  121. optimum/rbln/utils/submodule.py +6 -8
  122. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
  123. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
  124. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  125. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  126. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
  127. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -36,6 +36,7 @@ from .decoderonly_architecture import (
     DecoderOnlyWrapper,
     set_default_values,
     validate_attention_method,
+    validate_sliding_window_size,
 )


@@ -56,36 +57,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         block_tables: torch.Tensor,
         free_block_pool: Deque,
-        kvcache_block_size: int,
-        use_attention_mask: bool,
-        attn_impl: str,
-        use_position_ids: bool,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
-
-        # shared data structures between prefill and decode phase
-        self.use_attention_mask = use_attention_mask
+        self.rbln_config = rbln_config

         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
-        self.use_position_ids = use_position_ids

-        self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
-        self.attn_impl = attn_impl
-
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
-            self.max_seq_len = kwargs.pop("max_seq_len")
-            self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
             self.output_size = [1, 1, vocab_size]
             self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )

     def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
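
The constructor refactor above threads a single `rbln_config` handle through the runtime instead of copying individual scalars. For readers following the rest of the diff, here is a minimal stand-in for that config object, reduced to the attributes this file actually dereferences; the real `RBLNDecoderOnlyModelForCausalLMConfig` lives in `configuration_decoderonly.py` and carries more, and the defaults shown are purely illustrative:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DecoderOnlyConfigSketch:
    # Hypothetical reduction of RBLNDecoderOnlyModelForCausalLMConfig; field
    # names match the attributes read via self.rbln_config in this module.
    max_seq_len: int = 8192
    prefill_chunk_size: int = 128
    kvcache_block_size: int = 8192
    kvcache_num_blocks: int = 1
    use_attention_mask: bool = False
    use_position_ids: bool = False
    use_inputs_embeds: bool = False
    attn_impl: str = "flash_attn"
    cache_impl: str = "static"  # "static" | "sliding_window" | "hybrid"
    sliding_window: Optional[int] = None  # window width for sliding/hybrid caches
    sliding_window_layers: List[int] = field(default_factory=list)  # hybrid only
    decoder_batch_sizes: List[int] = field(default_factory=lambda: [1])
```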
@@ -131,31 +121,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             raise RuntimeError(NO_BLOCKS_ERROR)

-        if self.phase == "prefill":
-            # Track previously used blocks and return them to the free_block_pool and
-            # reset the current batch's block table to empty blocks
-            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-            self.free_block_pool.extend(prev_blocks)
-            self.block_tables[batch_idx].fill_(self.empty_block)
-
-            # Get the start (s) and end (e) positions from cache_position and
-            # iterate over the cache positions to allocate necessary blocks
-            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-            for position in range(s, e + 1, self.kvcache_block_size):
-                block_idx = position // self.kvcache_block_size
-                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                update_block(batch_idx, block_idx)
-
-            return replace_empty_block(self.block_tables[batch_idx])
-        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-        else:
-            for b_idx in range(self.batch_size):
-                position = cache_position[b_idx][0].item()
-                block_idx = position // self.kvcache_block_size
-                update_block(b_idx, block_idx)
+        def get_global_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "sliding_window":
+                return None
+
+            if self.phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.empty_block)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                    update_block(batch_idx, block_idx)
+
+                return replace_empty_block(self.block_tables[batch_idx])
+            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(self.batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    update_block(b_idx, block_idx)
+
+                return replace_empty_block(self.block_tables)

-        return replace_empty_block(self.block_tables)
+        def get_local_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "static":
+                return None
+            else:
+                return (
+                    torch.tensor([batch_idx], dtype=torch.int16)
+                    if self.phase == "prefill"
+                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
+                )
+
+        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
+
+    def is_external_block_tables(
+        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+    ):
+        if self.rbln_config.cache_impl == "static" and block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "hybrid":
+            if (block_tables is not None) != (local_block_tables is not None):
+                raise ValueError(
+                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
+                )
+            elif block_tables is None and local_block_tables is None:
+                return False
+        else:
+            return True

     def forward(
         self,
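
The refactored `get_block_tables` now returns a pair: a global table (`None` when `cache_impl` is "sliding_window") and a local table (`None` when it is "static"). The global side is plain paged-KV bookkeeping: every `kvcache_block_size`-sized span of cache positions is pinned to one physical block drawn from a free pool. A self-contained sketch of that allocation, with a hypothetical helper name and no RBLN runtime required:

```python
from collections import deque

import torch


def allocate_prefill_blocks(
    block_tables: torch.Tensor,  # [batch, max_blocks], -1 marks an empty slot
    free_block_pool: deque,
    cache_position: torch.Tensor,  # [1, query_length]
    batch_idx: int,
    block_size: int,
) -> torch.Tensor:
    # Return this row's previous blocks to the pool, then claim one block per
    # block_size-sized span covered by the incoming cache positions.
    row = block_tables[batch_idx]
    free_block_pool.extend(row[row != -1].tolist())
    row.fill_(-1)
    s, e = cache_position[0, 0].item(), cache_position[0, -1].item()
    for position in range(s, e + 1, block_size):
        row[position // block_size] = free_block_pool.popleft()
    return row


tables = torch.full((2, 4), -1, dtype=torch.int16)
pool = deque(range(8))
# Positions 0..7 with block_size=4 claim two blocks for sequence 0:
print(allocate_prefill_blocks(tables, pool, torch.arange(8).unsqueeze(0), 0, 4))
# tensor([ 0,  1, -1, -1], dtype=torch.int16)
```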
@@ -180,11 +203,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds

-        if block_tables is None:
-            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-            is_external_block_tables = False
-        else:
-            is_external_block_tables = True
+        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+        if not is_external_block_tables:
+            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)

         if self.phase == "decode":
             return self.decode_forward(
@@ -204,6 +225,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask,
                 batch_idx,
                 block_tables,
+                is_external_block_tables=is_external_block_tables,
                 position_embed=position_embed,
                 token_type_ids=token_type_ids,
                 local_block_tables=local_block_tables,
@@ -229,7 +251,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -245,7 +267,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

             attention_mask = self.dec_attn_mask

-        if self.batch_size < block_tables.shape[0]:
+        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
             block_tables = block_tables[: self.batch_size]

         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -255,9 +277,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
             block_tables,
+            local_block_tables,
             position_embed,
-            attention_mask if self.use_attention_mask else None,
-            position_ids if self.use_position_ids else None,
+            attention_mask if self.rbln_config.use_attention_mask else None,
+            position_ids if self.rbln_config.use_position_ids else None,
         )

         return RBLNDecoderOnlyOutput(logits=logits)
@@ -268,7 +291,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
@@ -283,15 +305,15 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )

         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )

         # Initialize attention mask for chunked processing
         chunked_attention_mask = (
-            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
-            if self.use_attention_mask
+            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+            if self.rbln_config.use_attention_mask
             else None
         )

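To see how this zero-initialized mask interacts with the `causal_mask` built in `__init__`, here is a standalone trace with a toy chunk size; the values are chosen purely for illustration:

```python
import torch

prefill_chunk_size, max_seq_len = 4, 12
# Lower-triangular mask for one chunk, as built in __init__ above.
causal_mask = 1 - torch.triu(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1)
chunked_attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len, dtype=torch.float32)

for step in (0, 4, 8):
    if step >= prefill_chunk_size:
        # Chunks already processed become fully visible ...
        chunked_attention_mask[:, :, :, step - prefill_chunk_size : step] = 1
    # ... while the current chunk stays causal.
    chunked_attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask

print(chunked_attention_mask[0, 0])  # after the last chunk: ones up to position 8, causal after
```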
@@ -305,8 +327,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         ]

         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
             # inputs_embeds
             if inputs.dim() == 3:
                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
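
The new `padding_size = 0` default lets later code reference the variable unconditionally; the rounding itself is plain modular arithmetic, as this worked trace of the rule above shows (chunk size assumed 128):

```python
prefill_chunk_size = 128
for query_length in (128, 300, 513):
    padding_size = 0
    if query_length % prefill_chunk_size != 0:
        padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size
    print(query_length, "->", padding_size, "->", query_length + padding_size)
# 128 -> 0 -> 128
# 300 -> 84 -> 384
# 513 -> 127 -> 640
```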
@@ -351,10 +374,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
+        is_external_block_tables: bool = False,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -375,39 +398,47 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )

         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
             if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
+                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]

-            if self.use_attention_mask and not self.use_position_ids:
+            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
-                if step >= self.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if step >= self.rbln_config.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask

             # Define query position
-            query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+            if step + self.rbln_config.prefill_chunk_size >= query_length:
+                query_position = torch.tensor(
+                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
+                )
+            else:
+                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)

             # Forward pass for the current chunk
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                local_block_tables,
                 position_embed_chunk if position_embed is not None else None,
                 query_position,
-                chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if self.use_position_ids else None,
+                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                position_ids_chunk if self.rbln_config.use_position_ids else None,
                 out=out_buffers,
             )

         # Update decoder attention mask with processed KV-cache length from prefill phase
-        if not is_external_block_tables and self.use_attention_mask:
+        if not is_external_block_tables and self.rbln_config.use_attention_mask:
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

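The reworked `query_position` keeps intermediate chunks pinned to their last token, and only the final chunk points at the true last-token offset; this matters once the sequence is padded up to a chunk multiple. A quick trace of exactly that branch:

```python
import torch

prefill_chunk_size, query_length = 128, 300
for step in range(0, query_length, prefill_chunk_size):
    if step + prefill_chunk_size >= query_length:
        query_position = torch.tensor((query_length - 1) % prefill_chunk_size, dtype=torch.int16)
    else:
        query_position = torch.tensor(prefill_chunk_size - 1, dtype=torch.int16)
    print(step, "->", int(query_position))
# 0 -> 127, 128 -> 127, 256 -> 43  (299 % 128 == 43)
```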
@@ -477,13 +508,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
            dec_attn_mask=dec_attn_mask,
            block_tables=block_tables,
            free_block_pool=free_block_pool,
-           kvcache_block_size=self.rbln_config.kvcache_block_size,
+           rbln_config=self.rbln_config,
            vocab_size=self.config.vocab_size,
-           prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-           max_seq_len=self.rbln_config.max_seq_len,
-           use_attention_mask=self.rbln_config.use_attention_mask,
-           attn_impl=self.rbln_config.attn_impl,
-           use_position_ids=self.rbln_config.use_position_ids,
        )

        self.decoders = {}
@@ -497,10 +523,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                dec_attn_mask=dec_attn_mask,
                block_tables=block_tables,
                free_block_pool=free_block_pool,
-               kvcache_block_size=self.rbln_config.kvcache_block_size,
-               use_attention_mask=self.rbln_config.use_attention_mask,
-               attn_impl=self.rbln_config.attn_impl,
-               use_position_ids=self.rbln_config.use_position_ids,
+               rbln_config=self.rbln_config,
            )

        # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -514,10 +537,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
        subfolder: str,
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    ):
-       """
-       If you are unavoidably running on a CPU rather than an RBLN device,
-       store the torch tensor, weight, etc. in this function.
-       """
+       # If you are unavoidably running on a CPU rather than an RBLN device,
+       # store the torch tensor, weight, etc. in this function.
        if rbln_config.use_inputs_embeds:
            save_dict = {}
            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
@@ -625,6 +646,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
            "use_attention_mask": rbln_config.use_attention_mask,
            "use_position_ids": rbln_config.use_position_ids,
            "use_inputs_embeds": rbln_config.use_inputs_embeds,
+           "cache_impl": rbln_config.cache_impl,
+           "sliding_window": rbln_config.sliding_window,
+           "sliding_window_layers": rbln_config.sliding_window_layers,
        }
        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

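With `cache_impl`, `sliding_window`, and `sliding_window_layers` now forwarded to the wrapper, each decoder layer can be routed to either a global or a local (sliding-window) KV cache. A minimal sketch of that routing rule, following the semantics documented later in this diff; the helper name is hypothetical:

```python
from typing import List


def layer_cache_kind(layer_idx: int, cache_impl: str, sliding_window_layers: List[int]) -> str:
    # "static": every layer uses the global paged cache;
    # "sliding_window": every layer uses the local windowed cache;
    # "hybrid": only the listed layers are windowed.
    if cache_impl == "static":
        return "global"
    if cache_impl == "sliding_window":
        return "local"
    return "local" if layer_idx in sliding_window_layers else "global"


assert layer_cache_kind(3, "hybrid", [1, 3, 5]) == "local"
assert layer_cache_kind(2, "hybrid", [1, 3, 5]) == "global"
```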
@@ -709,7 +733,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
            compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
        ):
            alloc_memory_by_key[key] += sum(memory_per_node)
-       alloc_memory_by_key.pop("PortRecur")  # kv-cache
+       alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
+       alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight

        # Get the maximum number of blocks that can be allocated
@@ -746,35 +771,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
        buffer: Optional[int] = None,
        num_runtimes: int = 2,
    ) -> int:
-       """
-       We are finding max_n_blocks(x) that satisfies the following equation:
-
-       available_dram - kernel_size - buffer
-       - num_layers * 2 * tensor_parallel_size
-       * align_2MB(
-           x
-           * block_size
-           * align_64(head_dim)
-           * math.ceil(num_key_value_heads / tensor_parallel_size)
-           * 2
-       ) > 0
-
-       This inequality can be rewritten as follows:
-
-       a - c * align_2MB(b * x) > 0
-       where
-           a = available_dram - kernel_size - buffer
-           b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-           c = num_layers * 2 * tensor_parallel_size
-
-       We can rewrite the inequality as follows:
-       k > align_2MB(b*x)
-       where
-           k = a / c
-
-       After that, we can derive the following equation:
-       x = floor(2**21 / b * floor((k - 1) / 2**21))
-       """
+       # We are finding max_n_blocks(x) that satisfies the following equation:
+
+       # available_dram - kernel_size - buffer
+       # - num_layers * 2 * tensor_parallel_size
+       # * align_2MB(
+       #     x
+       #     * block_size
+       #     * align_64(head_dim)
+       #     * math.ceil(num_key_value_heads / tensor_parallel_size)
+       #     * 2
+       # ) > 0
+
+       # This inequality can be rewritten as follows:
+
+       # a - c * align_2MB(b * x) > 0
+       # where
+       #     a = available_dram - kernel_size - buffer
+       #     b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+       #     c = num_layers * 2 * tensor_parallel_size
+
+       # We can rewrite the inequality as follows:
+       # k > align_2MB(b*x)
+       # where
+       #     k = a / c
+
+       # After that, we can derive the following equation:
+       # x = floor(2**21 / b * floor((k - 1) / 2**21))

        def align(x: int, nbytes: int) -> int:
            return int(math.ceil(x / nbytes) * nbytes)
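
Translating that derivation into runnable form, with illustrative numbers; `align_2MB` is alignment up to 2 MiB = 2**21 bytes, and the final line is a literal transcription of the closed-form `x` from the comment above:

```python
import math


def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)


def max_n_blocks_sketch(available_dram: int, kernel_size: int, buffer: int,
                        num_layers: int, tensor_parallel_size: int,
                        block_size: int, head_dim: int, num_key_value_heads: int) -> int:
    # Solve a - c * align_2MB(b * x) > 0 for the largest integer x.
    a = available_dram - kernel_size - buffer
    b = block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
    c = num_layers * 2 * tensor_parallel_size
    k = a / c
    return int(2**21 / b * ((k - 1) // 2**21))


# Illustrative only: 16 GiB DRAM, 2 GiB of weights, Llama-7B-like geometry.
print(max_n_blocks_sketch(16 * 2**30, 2 * 2**30, 0,
                          num_layers=32, tensor_parallel_size=1,
                          block_size=128, head_dim=128, num_key_value_heads=32))
# 222
```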
@@ -833,22 +856,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
        cls,
        batch_size: int,
        query_length: int,
-       use_inputs_embeds: bool,
-       use_attention_mask: bool,
-       use_position_ids: bool,
-       max_seq_len: int,
-       kvcache_block_size: int,
-       kvcache_num_blocks: int,
-       num_key_value_heads: int,
-       num_hidden_layers: int,
-       hidden_size: int,
-       head_dim: int,
+       rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+       model_config: PretrainedConfig,
    ):
-       if use_inputs_embeds:
+       is_prefill: bool = query_length > 1
+       num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+       num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+       num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+       hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+       head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
+       local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+
+       # 1. main input
+       if rbln_config.use_inputs_embeds:
            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
        else:
            main_input = ("input_ids", [batch_size, query_length], "int64")

+       # 2. cache_position
        input_info = [
            main_input,
            (
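
The `getattr` chains accept both Hugging Face naming conventions: GPT-2-era configs expose `n_head`/`n_layer`/`n_embd`, newer ones `num_attention_heads`/`num_hidden_layers`/`hidden_size`. A self-contained illustration, using plain namespaces in place of `PretrainedConfig`:

```python
from types import SimpleNamespace

gpt2_like = SimpleNamespace(n_head=12, n_layer=12, n_embd=768)
llama_like = SimpleNamespace(num_attention_heads=32, num_hidden_layers=32, hidden_size=4096)

for cfg in (gpt2_like, llama_like):
    num_attention_heads = getattr(cfg, "n_head", None) or getattr(cfg, "num_attention_heads")
    hidden_size = getattr(cfg, "n_embd", None) or getattr(cfg, "hidden_size")
    head_dim = getattr(cfg, "head_dim", None) or hidden_size // num_attention_heads
    print(num_attention_heads, hidden_size, head_dim)
# 12 768 64
# 32 4096 128
```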
@@ -858,38 +883,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
            ),
        ]

-       max_block_cnt = max_seq_len // kvcache_block_size
-
-       if query_length > 1:
-           input_info.extend([("block_tables", [max_block_cnt], "int16")])
-       else:
-           input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-       if query_length > 1:
+       # 3. block_tables
+       if rbln_config.cache_impl in ["static", "hybrid"]:
+           max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
            input_info.extend(
-               [
-                   ("query_position", [], "int16"),
-               ]
+               [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
            )
-       if use_attention_mask:
+       if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+           input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+
+       # 4. query_position
+       if is_prefill:
+           input_info.extend([("query_position", [], "int16")])
+
+       # 5. attention_mask & position_ids
+       if rbln_config.use_attention_mask:
            input_info.extend(
                [
-                   ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                   ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
+                   if rbln_config.use_position_ids
+                   else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
                ]
            )
-       if use_position_ids:
+       if rbln_config.use_position_ids:
            input_info.append(("position_ids", [batch_size, query_length], "int32"))

+       # 6. past_key_values
+       global_kvcache_shape = [
+           rbln_config.kvcache_num_blocks,
+           num_key_value_heads,
+           rbln_config.kvcache_block_size,
+           head_dim,
+       ]
+       local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
        input_info.extend(
            [
                (
                    f"past_key_values_{i}",
-                   [
-                       kvcache_num_blocks,
-                       num_key_value_heads,
-                       kvcache_block_size,
-                       head_dim,
-                   ],
+                   local_kvcache_shape
+                   if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
+                   else global_kvcache_shape,
                    "float32",
                )
                for i in range(num_hidden_layers * 2)
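
Since every layer contributes a key and a value tensor, cache input `i` belongs to layer `i // 2`, and for hybrid models that index decides between the two shapes. A toy trace under assumed values:

```python
# Assumed values for illustration; the shape selection mirrors the code above.
kvcache_num_blocks, kvcache_block_size = 256, 128
sliding_window = 512
local_kvcache_num_blocks = 2  # max(decoder_batch_sizes)
num_key_value_heads, head_dim = 8, 128
sliding_window_layers = [0, 2]  # hybrid: layers 0 and 2 use the local cache
num_hidden_layers = 4

global_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
local_shape = [local_kvcache_num_blocks, num_key_value_heads, sliding_window, head_dim]

for i in range(num_hidden_layers * 2):  # even i: key cache, odd i: value cache
    shape = local_shape if (i // 2) in sliding_window_layers else global_shape
    print(f"past_key_values_{i}", shape)
# layers 0 and 2 get [2, 8, 512, 128]; layers 1 and 3 get [256, 8, 128, 128]
```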
@@ -898,6 +931,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

        return input_info

+   @classmethod
+   def _update_sliding_window_config(
+       cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+   ):
+       # Update the sliding window configuration for the RBLN model.
+
+       # This method must be implemented by subclasses to handle their specific sliding window configurations,
+       # as Hugging Face models use different configuration keys to represent sliding window layers.
+
+       # Args:
+       #     model_config (PretrainedConfig): The model configuration from Hugging Face.
+       #     rbln_config (RBLNDecoderOnlyModelForCausalLMConfig): The RBLN model configuration.
+
+       # Notes:
+       #     Required configuration settings:
+       #     - `cache_impl`: Must be one of:
+       #         - "static": All layers use global attention (no sliding window)
+       #         - "sliding_window": All layers use sliding window attention
+       #         - "hybrid": A mix of global and sliding window attention layers
+       #     - `sliding_window`: Width of the sliding window (required if cache_impl is "sliding_window" or "hybrid")
+       #     - `sliding_window_layers`: List of layer indices using sliding window attention (required if cache_impl is "hybrid")
+
+       # Example implementation for a 'sliding_window' model:
+       # ```python
+       # rbln_config.cache_impl = "sliding_window"
+       # rbln_config.sliding_window = model_config.sliding_window
+       # rbln_config.sliding_window_layers = [i for i in range(model_config.num_hidden_layers)]
+       # return rbln_config
+       # ```
+
+       # Returns:
+       #     RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+
+       raise NotImplementedError(
+           "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
+           "See method docstring for required configuration details."
+       )
+
    @classmethod
    def _update_rbln_config(
        cls,
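
The inline example covers the all-sliding case; a hybrid model would also populate `sliding_window_layers`. A sketch of such an override, where the `layer_types` attribute and its `"sliding_attention"` marker are assumptions standing in for whatever keys a given architecture's Hugging Face config actually uses:

```python
class RBLNMyHybridModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    @classmethod
    def _update_sliding_window_config(cls, model_config, rbln_config):
        # Hypothetical HF config layout: a per-layer list marking windowed layers.
        rbln_config.cache_impl = "hybrid"
        rbln_config.sliding_window = model_config.sliding_window
        rbln_config.sliding_window_layers = [
            i for i, kind in enumerate(model_config.layer_types) if kind == "sliding_attention"
        ]
        return rbln_config
```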
@@ -913,6 +984,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
        if rbln_config.max_seq_len is None:
            raise ValueError("`max_seq_len` should be specified.")

+       if getattr(model_config, "sliding_window", None) is not None and getattr(
+           model_config, "use_sliding_window", True
+       ):
+           rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+           if rbln_config.sliding_window is not None:
+               validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+
        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
            attn_impl=rbln_config.attn_impl,
            kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -961,25 +1039,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                "This can cause a failure during model compilation."
            )
        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-       num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-       num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-       num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-       hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-       head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads

        prefill_input_info = cls.get_input_info(
            batch_size=1,
            query_length=rbln_config.prefill_chunk_size,
-           use_inputs_embeds=rbln_config.use_inputs_embeds,
-           use_attention_mask=rbln_config.use_attention_mask,
-           use_position_ids=rbln_config.use_position_ids,
-           max_seq_len=rbln_config.max_seq_len,
-           kvcache_block_size=rbln_config.kvcache_block_size,
-           kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-           num_key_value_heads=num_key_value_heads,
-           num_hidden_layers=num_hidden_layers,
-           hidden_size=hidden_size,
-           head_dim=head_dim,
+           rbln_config=rbln_config,
+           model_config=model_config,
        )

        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
@@ -989,16 +1054,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
            dec_input_info = cls.get_input_info(
                batch_size=batch_size,
                query_length=1,
-               use_inputs_embeds=rbln_config.use_inputs_embeds,
-               use_attention_mask=rbln_config.use_attention_mask,
-               use_position_ids=rbln_config.use_position_ids,
-               max_seq_len=rbln_config.max_seq_len,
-               kvcache_block_size=rbln_config.kvcache_block_size,
-               kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-               num_key_value_heads=num_key_value_heads,
-               num_hidden_layers=num_hidden_layers,
-               hidden_size=hidden_size,
-               head_dim=head_dim,
+               rbln_config=rbln_config,
+               model_config=model_config,
            )
            dec_compile_configs.append(
                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
@@ -1122,12 +1179,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
        return_dict: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor]:
-       """
-       Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
-       For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
-       A for-loop ensures synchronization with the HuggingFace generate API.
-       The decoder stage operates as usual, processing inputs in batch mode.
-       """
+       # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+       # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+       # A for-loop ensures synchronization with the HuggingFace generate API.
+       # The decoder stage operates as usual, processing inputs in batch mode.
+
        # Prefll
        if cache_position is None:
            logits = []