optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Files changed (57)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +40 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +31 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +204 -44
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +124 -208
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +565 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +0 -6
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +10 -6
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  27. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  28. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  29. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  30. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  31. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  32. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  33. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  34. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  35. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  36. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  37. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  38. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  39. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  40. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  41. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  42. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  43. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  44. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  45. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  46. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  47. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  48. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  49. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  50. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  51. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  52. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  53. optimum/rbln/utils/depreacate_utils.py +16 -0
  54. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/METADATA +1 -1
  55. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/RECORD +57 -51
  56. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/WHEEL +0 -0
  57. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -80,12 +80,16 @@ _import_structure = {
  "RBLNDPTForDepthEstimationConfig",
  "RBLNExaoneForCausalLM",
  "RBLNExaoneForCausalLMConfig",
+ "RBLNGemmaModel",
+ "RBLNGemmaModelConfig",
  "RBLNGemmaForCausalLM",
  "RBLNGemmaForCausalLMConfig",
  "RBLNGemma3ForCausalLM",
  "RBLNGemma3ForCausalLMConfig",
  "RBLNGemma3ForConditionalGeneration",
  "RBLNGemma3ForConditionalGenerationConfig",
+ "RBLNGPT2Model",
+ "RBLNGPT2ModelConfig",
  "RBLNGPT2LMHeadModel",
  "RBLNGPT2LMHeadModelConfig",
  "RBLNIdefics3VisionTransformer",
@@ -94,22 +98,36 @@ _import_structure = {
  "RBLNIdefics3VisionTransformerConfig",
  "RBLNLlamaForCausalLM",
  "RBLNLlamaForCausalLMConfig",
+ "RBLNLlamaModel",
+ "RBLNLlamaModelConfig",
  "RBLNOPTForCausalLM",
  "RBLNOPTForCausalLMConfig",
  "RBLNLlavaNextForConditionalGeneration",
  "RBLNLlavaNextForConditionalGenerationConfig",
  "RBLNMidmLMHeadModel",
  "RBLNMidmLMHeadModelConfig",
+ "RBLNMistralModel",
+ "RBLNMistralModelConfig",
  "RBLNMistralForCausalLM",
  "RBLNMistralForCausalLMConfig",
+ "RBLNOPTModel",
+ "RBLNOPTModelConfig",
+ "RBLNPegasusForConditionalGeneration",
+ "RBLNPegasusForConditionalGenerationConfig",
+ "RBLNPegasusModel",
+ "RBLNPegasusModelConfig",
  "RBLNPhiForCausalLM",
  "RBLNPhiForCausalLMConfig",
+ "RBLNPhiModel",
+ "RBLNPhiModelConfig",
  "RBLNQwen2ForCausalLM",
  "RBLNQwen2ForCausalLMConfig",
  "RBLNQwen2_5_VisionTransformerPretrainedModel",
  "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
  "RBLNQwen2_5_VLForConditionalGeneration",
  "RBLNQwen2_5_VLForConditionalGenerationConfig",
+ "RBLNQwen2Model",
+ "RBLNQwen2ModelConfig",
  "RBLNQwen3ForCausalLM",
  "RBLNQwen3ForCausalLMConfig",
  "RBLNQwen3Model",
@@ -337,30 +355,48 @@ if TYPE_CHECKING:
  RBLNGemma3ForConditionalGenerationConfig,
  RBLNGemmaForCausalLM,
  RBLNGemmaForCausalLMConfig,
+ RBLNGemmaModel,
+ RBLNGemmaModelConfig,
  RBLNGPT2LMHeadModel,
  RBLNGPT2LMHeadModelConfig,
+ RBLNGPT2Model,
+ RBLNGPT2ModelConfig,
  RBLNIdefics3ForConditionalGeneration,
  RBLNIdefics3ForConditionalGenerationConfig,
  RBLNIdefics3VisionTransformer,
  RBLNIdefics3VisionTransformerConfig,
  RBLNLlamaForCausalLM,
  RBLNLlamaForCausalLMConfig,
+ RBLNLlamaModel,
+ RBLNLlamaModelConfig,
  RBLNLlavaNextForConditionalGeneration,
  RBLNLlavaNextForConditionalGenerationConfig,
  RBLNMidmLMHeadModel,
  RBLNMidmLMHeadModelConfig,
  RBLNMistralForCausalLM,
  RBLNMistralForCausalLMConfig,
+ RBLNMistralModel,
+ RBLNMistralModelConfig,
  RBLNOPTForCausalLM,
  RBLNOPTForCausalLMConfig,
+ RBLNOPTModel,
+ RBLNOPTModelConfig,
+ RBLNPegasusForConditionalGeneration,
+ RBLNPegasusForConditionalGenerationConfig,
+ RBLNPegasusModel,
+ RBLNPegasusModelConfig,
  RBLNPhiForCausalLM,
  RBLNPhiForCausalLMConfig,
+ RBLNPhiModel,
+ RBLNPhiModelConfig,
  RBLNQwen2_5_VisionTransformerPretrainedModel,
  RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
  RBLNQwen2_5_VLForConditionalGeneration,
  RBLNQwen2_5_VLForConditionalGenerationConfig,
  RBLNQwen2ForCausalLM,
  RBLNQwen2ForCausalLMConfig,
+ RBLNQwen2Model,
+ RBLNQwen2ModelConfig,
  RBLNQwen3ForCausalLM,
  RBLNQwen3ForCausalLMConfig,
  RBLNQwen3Model,
optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.8.2a4'
- __version_tuple__ = version_tuple = (0, 8, 2, 'a4')
+ __version__ = version = '0.8.2a5'
+ __version_tuple__ = version_tuple = (0, 8, 2, 'a5')
optimum/rbln/configuration_utils.py CHANGED
@@ -23,6 +23,7 @@ import numpy as np
  import torch
 
  from .__version__ import __version__
+ from .utils.depreacate_utils import warn_deprecated_npu
  from .utils.logging import get_logger
  from .utils.runtime_utils import ContextRblnConfig
 
@@ -675,6 +676,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
      compile_cfg.npu = self.npu
      compile_cfg.tensor_parallel_size = self.tensor_parallel_size
 
+     target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
+     warn_deprecated_npu(target_npu)
+
  def freeze(self):
      if self._frozen:
          raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
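Note on the hunk above: the deprecation check uses the top-level `npu` if set, otherwise the first compile config that specifies one. A minimal sketch of that fallback idiom with stand-in objects (the body of `warn_deprecated_npu` itself is not shown in this diff):

    from types import SimpleNamespace

    compile_cfgs = [SimpleNamespace(npu=None), SimpleNamespace(npu="some-npu"), SimpleNamespace(npu=None)]
    top_level_npu = None

    # First non-None npu among the compile configs, or None if nothing is set.
    target_npu = top_level_npu or next((cfg.npu for cfg in compile_cfgs if cfg.npu is not None), None)
    print(target_npu)  # "some-npu"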
optimum/rbln/ops/kv_cache_update.py CHANGED
@@ -22,3 +22,8 @@ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor
      # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
      # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
      return torch.empty_like(cache)
+
+
+ @rbln_cache_update.register_fake
+ def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
+     return torch.empty_like(cache)
optimum/rbln/ops/linear.py CHANGED
@@ -23,3 +23,10 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
      output_shape = list(input.shape[:-1])
      output_shape += [weight.shape[0]]
      return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
+
+
+ @linear.register_fake
+ def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+     output_shape = list(input.shape[:-1])
+     output_shape += [weight.shape[0]]
+     return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
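The two ops hunks above follow the same pattern: each RBLN custom op gains a `register_fake` implementation so that tracing machinery such as `torch.export` / `torch.compile` can infer output shapes and dtypes without running the op on device. A minimal standalone sketch of that pattern on recent PyTorch, with illustrative names that are not part of optimum-rbln:

    import torch
    from torch import Tensor

    # Declare a custom op; the "real" implementation runs eagerly.
    @torch.library.custom_op("demo::scale", mutates_args=())
    def scale(x: Tensor, factor: float) -> Tensor:
        return x * factor

    # The fake implementation only needs to produce a tensor with the right
    # metadata (shape/dtype/device); it is what tracing and compilation use.
    @scale.register_fake
    def _(x: Tensor, factor: float) -> Tensor:
        return torch.empty_like(x)

    print(scale(torch.ones(2, 3), 2.0).shape)  # torch.Size([2, 3])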
optimum/rbln/transformers/__init__.py CHANGED
@@ -68,6 +68,8 @@ _import_structure = {
  "RBLNDPTForDepthEstimationConfig",
  "RBLNExaoneForCausalLM",
  "RBLNExaoneForCausalLMConfig",
+ "RBLNGemmaModel",
+ "RBLNGemmaModelConfig",
  "RBLNGemma3ForCausalLM",
  "RBLNGemma3ForCausalLMConfig",
  "RBLNGemma3ForConditionalGeneration",
@@ -76,26 +78,44 @@ _import_structure = {
  "RBLNGemmaForCausalLMConfig",
  "RBLNGPT2LMHeadModel",
  "RBLNGPT2LMHeadModelConfig",
+ "RBLNGPT2Model",
+ "RBLNGPT2ModelConfig",
  "RBLNIdefics3ForConditionalGeneration",
  "RBLNIdefics3ForConditionalGenerationConfig",
  "RBLNIdefics3VisionTransformer",
  "RBLNIdefics3VisionTransformerConfig",
  "RBLNLlamaForCausalLM",
  "RBLNLlamaForCausalLMConfig",
+ "RBLNLlamaModel",
+ "RBLNLlamaModelConfig",
+ "RBLNOPTForCausalLM",
+ "RBLNOPTForCausalLMConfig",
+ "RBLNPegasusForConditionalGeneration",
+ "RBLNPegasusForConditionalGenerationConfig",
+ "RBLNPegasusModel",
+ "RBLNPegasusModelConfig",
  "RBLNLlavaNextForConditionalGeneration",
  "RBLNLlavaNextForConditionalGenerationConfig",
  "RBLNMidmLMHeadModel",
  "RBLNMidmLMHeadModelConfig",
  "RBLNMistralForCausalLM",
  "RBLNMistralForCausalLMConfig",
+ "RBLNMistralModel",
+ "RBLNMistralModelConfig",
  "RBLNOPTForCausalLM",
  "RBLNOPTForCausalLMConfig",
+ "RBLNOPTModel",
+ "RBLNOPTModelConfig",
  "RBLNPhiForCausalLM",
  "RBLNPhiForCausalLMConfig",
+ "RBLNPhiModel",
+ "RBLNPhiModelConfig",
  "RBLNQwen2_5_VisionTransformerPretrainedModel",
  "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
  "RBLNQwen2_5_VLForConditionalGeneration",
  "RBLNQwen2_5_VLForConditionalGenerationConfig",
+ "RBLNQwen2Model",
+ "RBLNQwen2ModelConfig",
  "RBLNQwen2ForCausalLM",
  "RBLNQwen2ForCausalLMConfig",
  "RBLNQwen3ForCausalLM",
@@ -170,6 +190,8 @@ if TYPE_CHECKING:
  RBLNCLIPVisionModelConfig,
  RBLNCLIPVisionModelWithProjection,
  RBLNCLIPVisionModelWithProjectionConfig,
+ RBLNColPaliForRetrieval,
+ RBLNColPaliForRetrievalConfig,
  RBLNDecoderOnlyModelForCausalLM,
  RBLNDecoderOnlyModelForCausalLMConfig,
  RBLNDistilBertForQuestionAnswering,
@@ -184,30 +206,48 @@ if TYPE_CHECKING:
  RBLNGemma3ForConditionalGenerationConfig,
  RBLNGemmaForCausalLM,
  RBLNGemmaForCausalLMConfig,
+ RBLNGemmaModel,
+ RBLNGemmaModelConfig,
  RBLNGPT2LMHeadModel,
  RBLNGPT2LMHeadModelConfig,
+ RBLNGPT2Model,
+ RBLNGPT2ModelConfig,
  RBLNIdefics3ForConditionalGeneration,
  RBLNIdefics3ForConditionalGenerationConfig,
  RBLNIdefics3VisionTransformer,
  RBLNIdefics3VisionTransformerConfig,
  RBLNLlamaForCausalLM,
  RBLNLlamaForCausalLMConfig,
+ RBLNLlamaModel,
+ RBLNLlamaModelConfig,
  RBLNLlavaNextForConditionalGeneration,
  RBLNLlavaNextForConditionalGenerationConfig,
  RBLNMidmLMHeadModel,
  RBLNMidmLMHeadModelConfig,
  RBLNMistralForCausalLM,
  RBLNMistralForCausalLMConfig,
+ RBLNMistralModel,
+ RBLNMistralModelConfig,
  RBLNOPTForCausalLM,
  RBLNOPTForCausalLMConfig,
+ RBLNOPTModel,
+ RBLNOPTModelConfig,
+ RBLNPegasusForConditionalGeneration,
+ RBLNPegasusForConditionalGenerationConfig,
+ RBLNPegasusModel,
+ RBLNPegasusModelConfig,
  RBLNPhiForCausalLM,
  RBLNPhiForCausalLMConfig,
+ RBLNPhiModel,
+ RBLNPhiModelConfig,
  RBLNQwen2_5_VisionTransformerPretrainedModel,
  RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
  RBLNQwen2_5_VLForConditionalGeneration,
  RBLNQwen2_5_VLForConditionalGenerationConfig,
  RBLNQwen2ForCausalLM,
  RBLNQwen2ForCausalLMConfig,
+ RBLNQwen2Model,
+ RBLNQwen2ModelConfig,
  RBLNQwen3ForCausalLM,
  RBLNQwen3ForCausalLMConfig,
  RBLNQwen3Model,
optimum/rbln/transformers/modeling_attention_utils.py ADDED
@@ -0,0 +1,252 @@
+ import math
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+ from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
+     RBLNDecoderOnlyModelForCausalLMConfig,
+ )
+
+ from ..utils.logging import get_logger
+
+
+ logger = get_logger()
+
+ if TYPE_CHECKING:
+     from rebel import RBLNCompiledModel
+     from transformers import PretrainedConfig
+
+
+ DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
+ DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
+ MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
+ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
+ MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
+ MAX_SLIDING_WINDOW_SIZE = 32_768
+
+
+ def set_default_values(
+     attn_impl: Optional[str] = None,
+     kvcache_partition_len: Optional[int] = None,
+     kvcache_block_size: Optional[int] = None,
+     max_seq_len: Optional[int] = None,
+ ) -> Tuple[str, int, int]:
+     if attn_impl is None:
+         attn_impl = "eager"
+
+     if kvcache_partition_len is not None:
+         if attn_impl == "eager":
+             attn_impl = "flash_attn"
+             logger.warning(
+                 "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
+                 "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
+                 "`attn_impl` has been automatically switched to 'flash_attn'."
+             )
+
+     if kvcache_partition_len is None and attn_impl == "flash_attn":
+         kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+     if kvcache_block_size is None:
+         if attn_impl == "eager":
+             kvcache_block_size = max_seq_len
+         else:
+             kvcache_block_size = kvcache_partition_len
+
+     return attn_impl, kvcache_partition_len, kvcache_block_size
+
+
+ def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
+     if attn_impl not in ["eager", "flash_attn"]:
+         raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
+
+     ## Checking Constraints...
+     # Constraint of eager attention:
+     # - `max_seq_len` <= 32k
+
+     # Constraints of flash attention:
+     # 1. `max_seq_len` should be multiple of `partition_len`.
+     # 2. 4k <= `partition_len` <= 32k.
+     # 3. `max_seq_len` should be larger then 8k.
+     if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
+         raise ValueError(
+             f"`max_seq_len` is set to {max_seq_len}, "
+             f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
+             f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+             " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
+         )
+
+     if attn_impl == "flash_attn":
+         if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
+             raise ValueError(
+                 f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
+                 f"when using 'flash_attn'. Please adjust either value to meet this requirement."
+             )
+         elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
+             raise ValueError(
+                 f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                 f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
+                 f"Please provide a valid value within this range."
+             )
+         elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
+             raise ValueError(
+                 f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
+                 f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
+                 "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
+             )
+
+     if kvcache_block_size is not None:
+         if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
+             raise ValueError(
+                 f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                 f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
+             )
+         elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
+             raise ValueError(
+                 f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                 f"must always be set equal to the `max_seq_len` {max_seq_len}."
+             )
+
+
+ def validate_sliding_window(rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+     if rbln_config.sliding_window > MAX_SLIDING_WINDOW_SIZE - rbln_config.prefill_chunk_size:
+         raise ValueError(
+             f"Sliding window size ({rbln_config.sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - rbln_config.prefill_chunk_size})"
+         )
+
+     if rbln_config.cache_impl == "sliding_window" and rbln_config.use_attention_mask:
+         raise ValueError("`use_attention_mask` must be set to False when `cache_impl` is set to 'sliding_window'.")
+
+
+ class RBLNDecoderOnlyFlashAttentionMixin:
+     @classmethod
+     def get_maximum_num_blocks(
+         cls,
+         config: "PretrainedConfig",
+         tensor_parallel_size: int,
+         kvcache_block_size: int,
+         nbits_per_param: Optional[int] = None,
+         n_model_params: Optional[int] = None,
+         kernel_size: Optional[int] = None,
+         buffer: Optional[int] = None,
+         num_runtimes: int = 2,
+     ) -> int:
+         # We are finding max_n_blocks(x) that satisfies the following equation:
+
+         # available_dram - kernel_size - buffer
+         # - num_layers * 2 * tensor_parallel_size
+         # * align_2MB(
+         #     x
+         #     * block_size
+         #     * align_64(head_dim)
+         #     * math.ceil(num_key_value_heads / tensor_parallel_size)
+         #     * 2
+         # ) > 0
+
+         # This inequality can be rewritten as follows:
+
+         # a - c * align_2MB(b * x) > 0
+         # where
+         #     a = available_dram - kernel_size - buffer
+         #     b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+         #     c = num_layers * 2 * tensor_parallel_size
+
+         # We can rewrite the inequality as follows:
+         # k > align_2MB(b*x)
+         # where
+         #     k = a / c
+
+         # After that, we can derive the following equation:
+         # x = floor(2**21 / b * floor((k - 1) / 2**21))
+
+         def align(x: int, nbytes: int) -> int:
+             return int(math.ceil(x / nbytes) * nbytes)
+
+         def align_2MB(x: int) -> int:
+             return align(x, 2**21)
+
+         num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
+         num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
+         head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
+         vocab_size = config.vocab_size
+         hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
+         num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+         # TODO(jongho): Update if target npu is REBEL.
+         ATOM_DRAM_NBYTES = 16 * 2**30
+         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
+         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
+
+         if kernel_size is None:
+             if n_model_params is None:
+                 raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
+             # Get estimated kernel size (approximated)
+             lm_heads_params = align(vocab_size, 64) * hidden_size
+             lm_heads_nbytes = (
+                 align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+             )
+             params = n_model_params - lm_heads_params
+             layer_nbytes = (
+                 align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                 * num_layers
+                 * tensor_parallel_size
+             )
+             kernel_size = layer_nbytes + lm_heads_nbytes
+         elif n_model_params is not None:
+             raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
+
+         available_dram -= kernel_size
+
+         if buffer is None:
+             # TODO: Accurate buffer estimation
+             buffer_per_runtime_per_core = 2**28  # 256MB per runtime
+             buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
+             buffer = buffer_per_core * tensor_parallel_size
+         available_dram -= buffer
+
+         b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+         c = num_layers * 2 * tensor_parallel_size
+         k = available_dram / c
+         max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+
+         return max_n_blocks
+
+     @classmethod
+     def maybe_suggest_kvcache_num_blocks(
+         cls,
+         compiled_models: Dict[str, "RBLNCompiledModel"],
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> None:
+         # Get the actual memory allocation of each node by key
+         alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
+         alloc_memory_by_key: Dict[str, int] = {
+             key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
+         }
+         for batch_size in rbln_config.decoder_batch_sizes:
+             for key, memory_per_node in (
+                 compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
+             ):
+                 alloc_memory_by_key[key] += sum(memory_per_node)
+         alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
+         alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
+         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
+
+         # Get the maximum number of blocks that can be allocated
+         buffer = sum(alloc_memory_by_key.values())
+         max_num_blocks = cls.get_maximum_num_blocks(
+             config=model_config,
+             tensor_parallel_size=rbln_config.tensor_parallel_size,
+             kvcache_block_size=rbln_config.kvcache_block_size,
+             kernel_size=kernel_size,
+             buffer=buffer,
+         )
+
+         # Since our estimation logic is not always accurate,
+         # users can set `kvcache_num_blocks` to `max_num_blocks`.
+         # If the memory is not enough, the model will fail to compile.
+         if rbln_config.kvcache_num_blocks < max_num_blocks:
+             logger.warning(
+                 f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                 "Our analysis indicates that additional memory is available for more blocks. "
+                 f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                 "Please be advised that our memory estimation algorithm has limitations, "
+                 "and increasing this value may not guarantee successful model compilation."
+             )
optimum/rbln/transformers/models/__init__.py CHANGED
@@ -92,27 +92,38 @@ _import_structure = {
      "RBLNDPTForDepthEstimationConfig",
  ],
  "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
- "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+ "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig", "RBLNGemmaModel", "RBLNGemmaModelConfig"],
  "gemma3": [
      "RBLNGemma3ForCausalLM",
      "RBLNGemma3ForCausalLMConfig",
      "RBLNGemma3ForConditionalGeneration",
      "RBLNGemma3ForConditionalGenerationConfig",
  ],
- "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
+ "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig", "RBLNGPT2Model", "RBLNGPT2ModelConfig"],
  "idefics3": [
      "RBLNIdefics3VisionTransformer",
      "RBLNIdefics3ForConditionalGeneration",
      "RBLNIdefics3ForConditionalGenerationConfig",
      "RBLNIdefics3VisionTransformerConfig",
  ],
- "llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig"],
- "opt": ["RBLNOPTForCausalLM", "RBLNOPTForCausalLMConfig"],
+ "llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig", "RBLNLlamaModel", "RBLNLlamaModelConfig"],
+ "opt": ["RBLNOPTForCausalLM", "RBLNOPTForCausalLMConfig", "RBLNOPTModel", "RBLNOPTModelConfig"],
+ "pegasus": [
+     "RBLNPegasusForConditionalGeneration",
+     "RBLNPegasusModel",
+     "RBLNPegasusForConditionalGenerationConfig",
+     "RBLNPegasusModelConfig",
+ ],
  "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
  "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
- "mistral": ["RBLNMistralForCausalLM", "RBLNMistralForCausalLMConfig"],
- "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig"],
- "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig"],
+ "mistral": [
+     "RBLNMistralForCausalLM",
+     "RBLNMistralForCausalLMConfig",
+     "RBLNMistralModel",
+     "RBLNMistralModelConfig",
+ ],
+ "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig", "RBLNPhiModel", "RBLNPhiModelConfig"],
+ "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig", "RBLNQwen2Model", "RBLNQwen2ModelConfig"],
  "qwen3": ["RBLNQwen3ForCausalLM", "RBLNQwen3ForCausalLMConfig", "RBLNQwen3Model", "RBLNQwen3ModelConfig"],
  "resnet": ["RBLNResNetForImageClassification", "RBLNResNetForImageClassificationConfig"],
  "roberta": [
@@ -215,27 +226,33 @@ if TYPE_CHECKING:
      RBLNDPTForDepthEstimationConfig,
  )
  from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
- from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+ from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
  from .gemma3 import (
      RBLNGemma3ForCausalLM,
      RBLNGemma3ForCausalLMConfig,
      RBLNGemma3ForConditionalGeneration,
      RBLNGemma3ForConditionalGenerationConfig,
  )
- from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
+ from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
  from .idefics3 import (
      RBLNIdefics3ForConditionalGeneration,
      RBLNIdefics3ForConditionalGenerationConfig,
      RBLNIdefics3VisionTransformer,
      RBLNIdefics3VisionTransformerConfig,
  )
- from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+ from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig, RBLNLlamaModel, RBLNLlamaModelConfig
  from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
  from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
- from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
- from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig
- from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig
- from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig
+ from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
+ from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig, RBLNOPTModel, RBLNOPTModelConfig
+ from .pegasus import (
+     RBLNPegasusForConditionalGeneration,
+     RBLNPegasusForConditionalGenerationConfig,
+     RBLNPegasusModel,
+     RBLNPegasusModelConfig,
+ )
+ from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig, RBLNPhiModel, RBLNPhiModelConfig
+ from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig, RBLNQwen2Model, RBLNQwen2ModelConfig
  from .qwen2_5_vl import (
      RBLNQwen2_5_VisionTransformerPretrainedModel,
      RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
optimum/rbln/transformers/models/decoderonly/__init__.py CHANGED
@@ -22,5 +22,5 @@ from ....ops import (
      paged_flash_causal_attn_decode,
      paged_flash_causal_attn_prefill,
  )
- from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
- from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
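Taken together, this release exposes backbone-only classes (for example `RBLNLlamaModel`, `RBLNQwen2Model`, `RBLNOPTModel`) alongside the existing `...ForCausalLM` wrappers, and adds Pegasus support. A hypothetical usage sketch, assuming the new classes follow the same optimum-style `from_pretrained(..., export=True)` entry point as the existing RBLN model classes:

    from optimum.rbln import RBLNLlamaModel

    # Compile the bare Llama backbone for an RBLN NPU; checkpoint choice is illustrative.
    model = RBLNLlamaModel.from_pretrained(
        "meta-llama/Llama-3.2-1B",
        export=True,
    )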