optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +4 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +40 -0
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/models/__init__.py +31 -14
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +204 -44
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +124 -208
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +565 -366
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +0 -6
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +10 -6
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
- optimum/rbln/utils/depreacate_utils.py +16 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/RECORD +57 -51
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
|
@@ -80,12 +80,16 @@ _import_structure = {
|
|
|
80
80
|
"RBLNDPTForDepthEstimationConfig",
|
|
81
81
|
"RBLNExaoneForCausalLM",
|
|
82
82
|
"RBLNExaoneForCausalLMConfig",
|
|
83
|
+
"RBLNGemmaModel",
|
|
84
|
+
"RBLNGemmaModelConfig",
|
|
83
85
|
"RBLNGemmaForCausalLM",
|
|
84
86
|
"RBLNGemmaForCausalLMConfig",
|
|
85
87
|
"RBLNGemma3ForCausalLM",
|
|
86
88
|
"RBLNGemma3ForCausalLMConfig",
|
|
87
89
|
"RBLNGemma3ForConditionalGeneration",
|
|
88
90
|
"RBLNGemma3ForConditionalGenerationConfig",
|
|
91
|
+
"RBLNGPT2Model",
|
|
92
|
+
"RBLNGPT2ModelConfig",
|
|
89
93
|
"RBLNGPT2LMHeadModel",
|
|
90
94
|
"RBLNGPT2LMHeadModelConfig",
|
|
91
95
|
"RBLNIdefics3VisionTransformer",
|
|
@@ -94,22 +98,36 @@ _import_structure = {
|
|
|
94
98
|
"RBLNIdefics3VisionTransformerConfig",
|
|
95
99
|
"RBLNLlamaForCausalLM",
|
|
96
100
|
"RBLNLlamaForCausalLMConfig",
|
|
101
|
+
"RBLNLlamaModel",
|
|
102
|
+
"RBLNLlamaModelConfig",
|
|
97
103
|
"RBLNOPTForCausalLM",
|
|
98
104
|
"RBLNOPTForCausalLMConfig",
|
|
99
105
|
"RBLNLlavaNextForConditionalGeneration",
|
|
100
106
|
"RBLNLlavaNextForConditionalGenerationConfig",
|
|
101
107
|
"RBLNMidmLMHeadModel",
|
|
102
108
|
"RBLNMidmLMHeadModelConfig",
|
|
109
|
+
"RBLNMistralModel",
|
|
110
|
+
"RBLNMistralModelConfig",
|
|
103
111
|
"RBLNMistralForCausalLM",
|
|
104
112
|
"RBLNMistralForCausalLMConfig",
|
|
113
|
+
"RBLNOPTModel",
|
|
114
|
+
"RBLNOPTModelConfig",
|
|
115
|
+
"RBLNPegasusForConditionalGeneration",
|
|
116
|
+
"RBLNPegasusForConditionalGenerationConfig",
|
|
117
|
+
"RBLNPegasusModel",
|
|
118
|
+
"RBLNPegasusModelConfig",
|
|
105
119
|
"RBLNPhiForCausalLM",
|
|
106
120
|
"RBLNPhiForCausalLMConfig",
|
|
121
|
+
"RBLNPhiModel",
|
|
122
|
+
"RBLNPhiModelConfig",
|
|
107
123
|
"RBLNQwen2ForCausalLM",
|
|
108
124
|
"RBLNQwen2ForCausalLMConfig",
|
|
109
125
|
"RBLNQwen2_5_VisionTransformerPretrainedModel",
|
|
110
126
|
"RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
|
|
111
127
|
"RBLNQwen2_5_VLForConditionalGeneration",
|
|
112
128
|
"RBLNQwen2_5_VLForConditionalGenerationConfig",
|
|
129
|
+
"RBLNQwen2Model",
|
|
130
|
+
"RBLNQwen2ModelConfig",
|
|
113
131
|
"RBLNQwen3ForCausalLM",
|
|
114
132
|
"RBLNQwen3ForCausalLMConfig",
|
|
115
133
|
"RBLNQwen3Model",
|
|
@@ -337,30 +355,48 @@ if TYPE_CHECKING:
|
|
|
337
355
|
RBLNGemma3ForConditionalGenerationConfig,
|
|
338
356
|
RBLNGemmaForCausalLM,
|
|
339
357
|
RBLNGemmaForCausalLMConfig,
|
|
358
|
+
RBLNGemmaModel,
|
|
359
|
+
RBLNGemmaModelConfig,
|
|
340
360
|
RBLNGPT2LMHeadModel,
|
|
341
361
|
RBLNGPT2LMHeadModelConfig,
|
|
362
|
+
RBLNGPT2Model,
|
|
363
|
+
RBLNGPT2ModelConfig,
|
|
342
364
|
RBLNIdefics3ForConditionalGeneration,
|
|
343
365
|
RBLNIdefics3ForConditionalGenerationConfig,
|
|
344
366
|
RBLNIdefics3VisionTransformer,
|
|
345
367
|
RBLNIdefics3VisionTransformerConfig,
|
|
346
368
|
RBLNLlamaForCausalLM,
|
|
347
369
|
RBLNLlamaForCausalLMConfig,
|
|
370
|
+
RBLNLlamaModel,
|
|
371
|
+
RBLNLlamaModelConfig,
|
|
348
372
|
RBLNLlavaNextForConditionalGeneration,
|
|
349
373
|
RBLNLlavaNextForConditionalGenerationConfig,
|
|
350
374
|
RBLNMidmLMHeadModel,
|
|
351
375
|
RBLNMidmLMHeadModelConfig,
|
|
352
376
|
RBLNMistralForCausalLM,
|
|
353
377
|
RBLNMistralForCausalLMConfig,
|
|
378
|
+
RBLNMistralModel,
|
|
379
|
+
RBLNMistralModelConfig,
|
|
354
380
|
RBLNOPTForCausalLM,
|
|
355
381
|
RBLNOPTForCausalLMConfig,
|
|
382
|
+
RBLNOPTModel,
|
|
383
|
+
RBLNOPTModelConfig,
|
|
384
|
+
RBLNPegasusForConditionalGeneration,
|
|
385
|
+
RBLNPegasusForConditionalGenerationConfig,
|
|
386
|
+
RBLNPegasusModel,
|
|
387
|
+
RBLNPegasusModelConfig,
|
|
356
388
|
RBLNPhiForCausalLM,
|
|
357
389
|
RBLNPhiForCausalLMConfig,
|
|
390
|
+
RBLNPhiModel,
|
|
391
|
+
RBLNPhiModelConfig,
|
|
358
392
|
RBLNQwen2_5_VisionTransformerPretrainedModel,
|
|
359
393
|
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
|
|
360
394
|
RBLNQwen2_5_VLForConditionalGeneration,
|
|
361
395
|
RBLNQwen2_5_VLForConditionalGenerationConfig,
|
|
362
396
|
RBLNQwen2ForCausalLM,
|
|
363
397
|
RBLNQwen2ForCausalLMConfig,
|
|
398
|
+
RBLNQwen2Model,
|
|
399
|
+
RBLNQwen2ModelConfig,
|
|
364
400
|
RBLNQwen3ForCausalLM,
|
|
365
401
|
RBLNQwen3ForCausalLMConfig,
|
|
366
402
|
RBLNQwen3Model,
|
optimum/rbln/__version__.py
CHANGED
|
@@ -17,5 +17,5 @@ __version__: str
|
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
|
18
18
|
version_tuple: VERSION_TUPLE
|
|
19
19
|
|
|
20
|
-
__version__ = version = '0.8.
|
|
21
|
-
__version_tuple__ = version_tuple = (0, 8, 2, '
|
|
20
|
+
__version__ = version = '0.8.2a5'
|
|
21
|
+
__version_tuple__ = version_tuple = (0, 8, 2, 'a5')
|
|
@@ -23,6 +23,7 @@ import numpy as np
|
|
|
23
23
|
import torch
|
|
24
24
|
|
|
25
25
|
from .__version__ import __version__
|
|
26
|
+
from .utils.depreacate_utils import warn_deprecated_npu
|
|
26
27
|
from .utils.logging import get_logger
|
|
27
28
|
from .utils.runtime_utils import ContextRblnConfig
|
|
28
29
|
|
|
@@ -675,6 +676,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
675
676
|
compile_cfg.npu = self.npu
|
|
676
677
|
compile_cfg.tensor_parallel_size = self.tensor_parallel_size
|
|
677
678
|
|
|
679
|
+
target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
|
|
680
|
+
warn_deprecated_npu(target_npu)
|
|
681
|
+
|
|
678
682
|
def freeze(self):
|
|
679
683
|
if self._frozen:
|
|
680
684
|
raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
|
|
@@ -22,3 +22,8 @@ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tens
|
|
|
22
22
|
# This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
|
|
23
23
|
# The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
|
|
24
24
|
return torch.empty_like(cache)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@rbln_cache_update.register_fake
def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
    """Fake kernel for `rbln_cache_update`, registered via `register_fake`.

    Returns an uninitialized tensor created with `torch.empty_like(cache)`, i.e. with the same
    shape/dtype/device as `cache`. `state`, `position` and `axis` are not read here.
    """
    placeholder = torch.empty_like(cache)
    return placeholder
|
optimum/rbln/ops/linear.py
CHANGED
|
@@ -23,3 +23,10 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens
|
|
|
23
23
|
output_shape = list(input.shape[:-1])
|
|
24
24
|
output_shape += [weight.shape[0]]
|
|
25
25
|
return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@linear.register_fake
def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """Fake kernel for `linear`, registered via `register_fake`.

    Produces an uninitialized tensor whose shape is `input.shape` with the last dimension replaced
    by `weight.shape[0]`; dtype, device and `requires_grad` are copied from `input`. `bias` does not
    influence the result.
    """
    result_size = [*input.shape[:-1], weight.shape[0]]
    return torch.empty(
        size=result_size,
        dtype=input.dtype,
        device=input.device,
        requires_grad=input.requires_grad,
    )
|
|
@@ -68,6 +68,8 @@ _import_structure = {
|
|
|
68
68
|
"RBLNDPTForDepthEstimationConfig",
|
|
69
69
|
"RBLNExaoneForCausalLM",
|
|
70
70
|
"RBLNExaoneForCausalLMConfig",
|
|
71
|
+
"RBLNGemmaModel",
|
|
72
|
+
"RBLNGemmaModelConfig",
|
|
71
73
|
"RBLNGemma3ForCausalLM",
|
|
72
74
|
"RBLNGemma3ForCausalLMConfig",
|
|
73
75
|
"RBLNGemma3ForConditionalGeneration",
|
|
@@ -76,26 +78,44 @@ _import_structure = {
|
|
|
76
78
|
"RBLNGemmaForCausalLMConfig",
|
|
77
79
|
"RBLNGPT2LMHeadModel",
|
|
78
80
|
"RBLNGPT2LMHeadModelConfig",
|
|
81
|
+
"RBLNGPT2Model",
|
|
82
|
+
"RBLNGPT2ModelConfig",
|
|
79
83
|
"RBLNIdefics3ForConditionalGeneration",
|
|
80
84
|
"RBLNIdefics3ForConditionalGenerationConfig",
|
|
81
85
|
"RBLNIdefics3VisionTransformer",
|
|
82
86
|
"RBLNIdefics3VisionTransformerConfig",
|
|
83
87
|
"RBLNLlamaForCausalLM",
|
|
84
88
|
"RBLNLlamaForCausalLMConfig",
|
|
89
|
+
"RBLNLlamaModel",
|
|
90
|
+
"RBLNLlamaModelConfig",
|
|
91
|
+
"RBLNOPTForCausalLM",
|
|
92
|
+
"RBLNOPTForCausalLMConfig",
|
|
93
|
+
"RBLNPegasusForConditionalGeneration",
|
|
94
|
+
"RBLNPegasusForConditionalGenerationConfig",
|
|
95
|
+
"RBLNPegasusModel",
|
|
96
|
+
"RBLNPegasusModelConfig",
|
|
85
97
|
"RBLNLlavaNextForConditionalGeneration",
|
|
86
98
|
"RBLNLlavaNextForConditionalGenerationConfig",
|
|
87
99
|
"RBLNMidmLMHeadModel",
|
|
88
100
|
"RBLNMidmLMHeadModelConfig",
|
|
89
101
|
"RBLNMistralForCausalLM",
|
|
90
102
|
"RBLNMistralForCausalLMConfig",
|
|
103
|
+
"RBLNMistralModel",
|
|
104
|
+
"RBLNMistralModelConfig",
|
|
91
105
|
"RBLNOPTForCausalLM",
|
|
92
106
|
"RBLNOPTForCausalLMConfig",
|
|
107
|
+
"RBLNOPTModel",
|
|
108
|
+
"RBLNOPTModelConfig",
|
|
93
109
|
"RBLNPhiForCausalLM",
|
|
94
110
|
"RBLNPhiForCausalLMConfig",
|
|
111
|
+
"RBLNPhiModel",
|
|
112
|
+
"RBLNPhiModelConfig",
|
|
95
113
|
"RBLNQwen2_5_VisionTransformerPretrainedModel",
|
|
96
114
|
"RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
|
|
97
115
|
"RBLNQwen2_5_VLForConditionalGeneration",
|
|
98
116
|
"RBLNQwen2_5_VLForConditionalGenerationConfig",
|
|
117
|
+
"RBLNQwen2Model",
|
|
118
|
+
"RBLNQwen2ModelConfig",
|
|
99
119
|
"RBLNQwen2ForCausalLM",
|
|
100
120
|
"RBLNQwen2ForCausalLMConfig",
|
|
101
121
|
"RBLNQwen3ForCausalLM",
|
|
@@ -170,6 +190,8 @@ if TYPE_CHECKING:
|
|
|
170
190
|
RBLNCLIPVisionModelConfig,
|
|
171
191
|
RBLNCLIPVisionModelWithProjection,
|
|
172
192
|
RBLNCLIPVisionModelWithProjectionConfig,
|
|
193
|
+
RBLNColPaliForRetrieval,
|
|
194
|
+
RBLNColPaliForRetrievalConfig,
|
|
173
195
|
RBLNDecoderOnlyModelForCausalLM,
|
|
174
196
|
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
175
197
|
RBLNDistilBertForQuestionAnswering,
|
|
@@ -184,30 +206,48 @@ if TYPE_CHECKING:
|
|
|
184
206
|
RBLNGemma3ForConditionalGenerationConfig,
|
|
185
207
|
RBLNGemmaForCausalLM,
|
|
186
208
|
RBLNGemmaForCausalLMConfig,
|
|
209
|
+
RBLNGemmaModel,
|
|
210
|
+
RBLNGemmaModelConfig,
|
|
187
211
|
RBLNGPT2LMHeadModel,
|
|
188
212
|
RBLNGPT2LMHeadModelConfig,
|
|
213
|
+
RBLNGPT2Model,
|
|
214
|
+
RBLNGPT2ModelConfig,
|
|
189
215
|
RBLNIdefics3ForConditionalGeneration,
|
|
190
216
|
RBLNIdefics3ForConditionalGenerationConfig,
|
|
191
217
|
RBLNIdefics3VisionTransformer,
|
|
192
218
|
RBLNIdefics3VisionTransformerConfig,
|
|
193
219
|
RBLNLlamaForCausalLM,
|
|
194
220
|
RBLNLlamaForCausalLMConfig,
|
|
221
|
+
RBLNLlamaModel,
|
|
222
|
+
RBLNLlamaModelConfig,
|
|
195
223
|
RBLNLlavaNextForConditionalGeneration,
|
|
196
224
|
RBLNLlavaNextForConditionalGenerationConfig,
|
|
197
225
|
RBLNMidmLMHeadModel,
|
|
198
226
|
RBLNMidmLMHeadModelConfig,
|
|
199
227
|
RBLNMistralForCausalLM,
|
|
200
228
|
RBLNMistralForCausalLMConfig,
|
|
229
|
+
RBLNMistralModel,
|
|
230
|
+
RBLNMistralModelConfig,
|
|
201
231
|
RBLNOPTForCausalLM,
|
|
202
232
|
RBLNOPTForCausalLMConfig,
|
|
233
|
+
RBLNOPTModel,
|
|
234
|
+
RBLNOPTModelConfig,
|
|
235
|
+
RBLNPegasusForConditionalGeneration,
|
|
236
|
+
RBLNPegasusForConditionalGenerationConfig,
|
|
237
|
+
RBLNPegasusModel,
|
|
238
|
+
RBLNPegasusModelConfig,
|
|
203
239
|
RBLNPhiForCausalLM,
|
|
204
240
|
RBLNPhiForCausalLMConfig,
|
|
241
|
+
RBLNPhiModel,
|
|
242
|
+
RBLNPhiModelConfig,
|
|
205
243
|
RBLNQwen2_5_VisionTransformerPretrainedModel,
|
|
206
244
|
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
|
|
207
245
|
RBLNQwen2_5_VLForConditionalGeneration,
|
|
208
246
|
RBLNQwen2_5_VLForConditionalGenerationConfig,
|
|
209
247
|
RBLNQwen2ForCausalLM,
|
|
210
248
|
RBLNQwen2ForCausalLMConfig,
|
|
249
|
+
RBLNQwen2Model,
|
|
250
|
+
RBLNQwen2ModelConfig,
|
|
211
251
|
RBLNQwen3ForCausalLM,
|
|
212
252
|
RBLNQwen3ForCausalLMConfig,
|
|
213
253
|
RBLNQwen3Model,
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
|
3
|
+
|
|
4
|
+
from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
|
|
5
|
+
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
from ..utils.logging import get_logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
logger = get_logger()
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from rebel import RBLNCompiledModel
|
|
15
|
+
from transformers import PretrainedConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# `kvcache_partition_len` used when 'flash_attn' is selected without an explicit partition length.
DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16 * 1024
# Largest `max_seq_len` accepted for 'eager' attention.
DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32 * 1024
# Smallest `max_seq_len` accepted for 'flash_attn'.
MIN_FLASH_ATTN_MAX_SEQ_LEN = 8 * 1024
# Inclusive bounds on `kvcache_partition_len` for 'flash_attn'.
MIN_FLASH_ATTN_PARTITION_LENGTH = 4 * 1024
MAX_FLASH_ATTN_PARTITION_LENGTH = 32 * 1024
# `sliding_window + prefill_chunk_size` must not exceed this value.
MAX_SLIDING_WINDOW_SIZE = 32 * 1024
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def set_default_values(
    attn_impl: Optional[str] = None,
    kvcache_partition_len: Optional[int] = None,
    kvcache_block_size: Optional[int] = None,
    max_seq_len: Optional[int] = None,
) -> Tuple[str, int, int]:
    """Resolve unset attention / KV-cache options to their defaults.

    Rules:
      * `attn_impl` defaults to "eager".
      * Supplying `kvcache_partition_len` while the implementation is "eager" switches it to
        "flash_attn" (with a warning), since partitioning needs flash attention.
      * With "flash_attn" and no partition length, DEFAULT_FLASH_ATTN_PARTITION_LENGTH is used.
      * An unset `kvcache_block_size` becomes `max_seq_len` for "eager" and the partition length otherwise.

    Returns:
        Tuple of (attn_impl, kvcache_partition_len, kvcache_block_size). The partition length stays
        None for "eager" when it was not given.
    """
    resolved_impl = "eager" if attn_impl is None else attn_impl

    if kvcache_partition_len is not None and resolved_impl == "eager":
        resolved_impl = "flash_attn"
        logger.warning(
            "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
            "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
            "`attn_impl` has been automatically switched to 'flash_attn'."
        )

    partition_len = kvcache_partition_len
    if partition_len is None and resolved_impl == "flash_attn":
        partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH

    block_size = kvcache_block_size
    if block_size is None:
        block_size = max_seq_len if resolved_impl == "eager" else partition_len

    return resolved_impl, partition_len, block_size
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
    """Check that a resolved attention configuration satisfies the supported constraints.

    'eager' attention:
      - `max_seq_len` <= DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH (32k).
      - `kvcache_block_size`, when given, must equal `max_seq_len`.

    'flash_attn' attention:
      1. MIN_FLASH_ATTN_PARTITION_LENGTH (4k) <= `kvcache_partition_len` <= MAX_FLASH_ATTN_PARTITION_LENGTH (32k).
      2. `max_seq_len` >= MIN_FLASH_ATTN_MAX_SEQ_LEN (8k).
      3. `max_seq_len` spans at least two partitions and is a multiple of `kvcache_partition_len`.
      4. `kvcache_block_size`, when given, must equal `kvcache_partition_len`.

    Args:
        attn_impl: "eager" or "flash_attn".
        kvcache_partition_len: Partition length; only checked for "flash_attn".
        kvcache_block_size: KV-cache block size, or None to skip the block-size checks.
        max_seq_len: Maximum sequence length.

    Raises:
        ValueError: if `attn_impl` is unknown or any constraint above is violated.
    """
    if attn_impl not in ("eager", "flash_attn"):
        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn')")

    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
        raise ValueError(
            f"`max_seq_len` is set to {max_seq_len}, "
            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
        )

    if attn_impl == "flash_attn":
        # Validate the partition length first: the partition-count checks below divide by it.
        if not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
            raise ValueError(
                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
                f"Please provide a valid value within this range."
            )
        # Checked before the partition-count rules so that short sequences get the actionable
        # "switch to eager" hint; otherwise the other checks would always fire first.
        if max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
            raise ValueError(
                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
            )
        if max_seq_len // kvcache_partition_len < 2:
            raise ValueError(
                f"`max_seq_len` ({max_seq_len}) must be at least twice `kvcache_partition_len` "
                f"({kvcache_partition_len}) when using 'flash_attn'. Please increase `max_seq_len` "
                "or decrease `kvcache_partition_len`."
            )
        if max_seq_len % kvcache_partition_len != 0:
            raise ValueError(
                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
            )

    if kvcache_block_size is not None:
        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
            raise ValueError(
                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
            )
        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
            raise ValueError(
                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
                f"must always be set equal to the `max_seq_len` {max_seq_len}."
            )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def validate_sliding_window(rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
    """Validate sliding-window settings of a decoder-only RBLN config.

    Raises:
        ValueError: if `sliding_window + prefill_chunk_size` exceeds MAX_SLIDING_WINDOW_SIZE, or if
            `use_attention_mask` is enabled while `cache_impl` is "sliding_window".
    """
    # Derive both the bound and its message from the constant so they cannot drift apart.
    max_sliding_window = MAX_SLIDING_WINDOW_SIZE - rbln_config.prefill_chunk_size
    if rbln_config.sliding_window > max_sliding_window:
        raise ValueError(
            f"Sliding window size ({rbln_config.sliding_window}) must be less than or equal to "
            f"{MAX_SLIDING_WINDOW_SIZE} - prefill_chunk_size ({max_sliding_window})"
        )

    if rbln_config.cache_impl == "sliding_window" and rbln_config.use_attention_mask:
        raise ValueError("`use_attention_mask` must be set to False when `cache_impl` is set to 'sliding_window'.")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class RBLNDecoderOnlyFlashAttentionMixin:
    """Mixin providing KV-cache memory estimation helpers for decoder-only models.

    `get_maximum_num_blocks` estimates how many KV-cache blocks fit into device DRAM.
    `maybe_suggest_kvcache_num_blocks` repeats that estimate using the memory actually reported
    by compiled models, and logs a suggestion when more blocks would fit.
    """

    @classmethod
    def get_maximum_num_blocks(
        cls,
        config: "PretrainedConfig",
        tensor_parallel_size: int,
        kvcache_block_size: int,
        nbits_per_param: Optional[int] = None,
        n_model_params: Optional[int] = None,
        kernel_size: Optional[int] = None,
        buffer: Optional[int] = None,
        num_runtimes: int = 2,
    ) -> int:
        """Estimate the maximum number of KV-cache blocks that fit into device DRAM.

        Args:
            config: Model config. Head/layer/hidden sizes are read from GPT-2 style names
                (`n_head`, `n_layer`, `n_embd`) or the standard names (`num_attention_heads`,
                `num_hidden_layers`, `hidden_size`).
            tensor_parallel_size: Number of devices the model is sharded over; available DRAM scales with it.
            kvcache_block_size: Number of tokens per KV-cache block.
            nbits_per_param: Bits per weight. Required when the kernel size is estimated from `n_model_params`.
            n_model_params: Total parameter count used to estimate weight memory. Mutually exclusive with `kernel_size`.
            kernel_size: Weight memory in bytes. Mutually exclusive with `n_model_params`.
            buffer: Other (non-weight, non-KV-cache) memory in bytes. Defaults to 256 MiB per runtime per device.
            num_runtimes: Number of runtimes used for the default `buffer` estimate.

        Returns:
            The estimated maximum number of blocks. It can be zero or negative when the weights and
            buffers already exhaust DRAM.

        Raises:
            ValueError: if neither or both of `n_model_params` and `kernel_size` are given, or if
                `nbits_per_param` is missing while the kernel size must be estimated.
        """
        # We are finding max_n_blocks(x) that satisfies the following inequality:
        #
        #   available_dram - kernel_size - buffer
        #     - num_layers * 2 * tensor_parallel_size
        #     * align_2MB(
        #         x
        #         * block_size
        #         * align_64(head_dim)
        #         * math.ceil(num_key_value_heads / tensor_parallel_size)
        #         * 2
        #     ) > 0
        #
        # This inequality can be rewritten as
        #   a - c * align_2MB(b * x) > 0
        # where
        #   a = available_dram - kernel_size - buffer
        #   b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
        #   c = num_layers * 2 * tensor_parallel_size
        #
        # Equivalently k > align_2MB(b * x) with k = a / c, from which
        #   x = floor(2**21 / b * floor((k - 1) / 2**21))

        def align(x: int, nbytes: int) -> int:
            # Round `x` up to the next multiple of `nbytes`.
            return int(math.ceil(x / nbytes) * nbytes)

        def align_2MB(x: int) -> int:
            return align(x, 2**21)

        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
        # Fall back to the hidden size resolved above so configs exposing only `n_embd` also work.
        head_dim = getattr(config, "head_dim", None) or hidden_size // num_attention_heads
        vocab_size = config.vocab_size
        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads

        # TODO(jongho): Update if target npu is REBEL.
        # Per-device DRAM size minus the region reserved for the system.
        ATOM_DRAM_NBYTES = 16 * 2**30
        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)

        if kernel_size is None:
            if n_model_params is None:
                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
            if nbits_per_param is None:
                raise ValueError("`nbits_per_param` should be specified to estimate the kernel memory.")
            # Get estimated kernel size (approximated): the LM head and the per-layer weights are each
            # split across devices and padded to 2MB.
            lm_heads_params = align(vocab_size, 64) * hidden_size
            lm_heads_nbytes = (
                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
            )
            params = n_model_params - lm_heads_params
            layer_nbytes = (
                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
                * num_layers
                * tensor_parallel_size
            )
            kernel_size = layer_nbytes + lm_heads_nbytes
        elif n_model_params is not None:
            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")

        available_dram -= kernel_size

        if buffer is None:
            # TODO: Accurate buffer estimation
            buffer_per_runtime_per_core = 2**28  # 256MB per runtime
            buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
            buffer = buffer_per_core * tensor_parallel_size
        available_dram -= buffer

        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
        c = num_layers * 2 * tensor_parallel_size
        k = available_dram / c
        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))

        return max_n_blocks

    @classmethod
    def maybe_suggest_kvcache_num_blocks(
        cls,
        compiled_models: Dict[str, "RBLNCompiledModel"],
        model_config: "PretrainedConfig",
        rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig",
    ) -> None:
        """Warn when the compiled models leave room for more KV-cache blocks.

        The allocations reported by the "prefill" model and every "decoder_batch_{N}" model are
        summed per allocation key. The KV-cache keys ("PortRecur", "DramTensor") are discarded,
        "Kernel" is used as the weight size, and all remaining keys are counted as buffer. These
        measured values are passed to `get_maximum_num_blocks` instead of its heuristics.

        Args:
            compiled_models: Compiled models keyed by "prefill" and "decoder_batch_{N}".
            model_config: Model config forwarded to `get_maximum_num_blocks`.
            rbln_config: Provides `decoder_batch_sizes`, `tensor_parallel_size`,
                `kvcache_block_size` and the current `kvcache_num_blocks`.

        Raises:
            KeyError: if no compiled model reports a "Kernel" allocation.
        """
        runtime_names = ["prefill"] + [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]

        # Get the actual memory allocation of each node by key, summed over all runtimes.
        alloc_memory_by_key: Dict[str, int] = {}
        for runtime_name in runtime_names:
            for key, memory_per_node in compiled_models[runtime_name].get_alloc_per_node_by_key().items():
                # A decoder model may report keys the prefill model does not, so start missing keys at 0.
                alloc_memory_by_key[key] = alloc_memory_by_key.get(key, 0) + sum(memory_per_node)

        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight

        # Get the maximum number of blocks that can be allocated
        buffer = sum(alloc_memory_by_key.values())
        max_num_blocks = cls.get_maximum_num_blocks(
            config=model_config,
            tensor_parallel_size=rbln_config.tensor_parallel_size,
            kvcache_block_size=rbln_config.kvcache_block_size,
            kernel_size=kernel_size,
            buffer=buffer,
        )

        # Since our estimation logic is not always accurate,
        # users can set `kvcache_num_blocks` to `max_num_blocks`.
        # If the memory is not enough, the model will fail to compile.
        if rbln_config.kvcache_num_blocks < max_num_blocks:
            logger.warning(
                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
                "Our analysis indicates that additional memory is available for more blocks. "
                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
                "Please be advised that our memory estimation algorithm has limitations, "
                "and increasing this value may not guarantee successful model compilation."
            )
|
|
@@ -92,27 +92,38 @@ _import_structure = {
|
|
|
92
92
|
"RBLNDPTForDepthEstimationConfig",
|
|
93
93
|
],
|
|
94
94
|
"exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
|
|
95
|
-
"gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
|
|
95
|
+
"gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig", "RBLNGemmaModel", "RBLNGemmaModelConfig"],
|
|
96
96
|
"gemma3": [
|
|
97
97
|
"RBLNGemma3ForCausalLM",
|
|
98
98
|
"RBLNGemma3ForCausalLMConfig",
|
|
99
99
|
"RBLNGemma3ForConditionalGeneration",
|
|
100
100
|
"RBLNGemma3ForConditionalGenerationConfig",
|
|
101
101
|
],
|
|
102
|
-
"gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
|
|
102
|
+
"gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig", "RBLNGPT2Model", "RBLNGPT2ModelConfig"],
|
|
103
103
|
"idefics3": [
|
|
104
104
|
"RBLNIdefics3VisionTransformer",
|
|
105
105
|
"RBLNIdefics3ForConditionalGeneration",
|
|
106
106
|
"RBLNIdefics3ForConditionalGenerationConfig",
|
|
107
107
|
"RBLNIdefics3VisionTransformerConfig",
|
|
108
108
|
],
|
|
109
|
-
"llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig"],
|
|
110
|
-
"opt": ["RBLNOPTForCausalLM", "RBLNOPTForCausalLMConfig"],
|
|
109
|
+
"llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig", "RBLNLlamaModel", "RBLNLlamaModelConfig"],
|
|
110
|
+
"opt": ["RBLNOPTForCausalLM", "RBLNOPTForCausalLMConfig", "RBLNOPTModel", "RBLNOPTModelConfig"],
|
|
111
|
+
"pegasus": [
|
|
112
|
+
"RBLNPegasusForConditionalGeneration",
|
|
113
|
+
"RBLNPegasusModel",
|
|
114
|
+
"RBLNPegasusForConditionalGenerationConfig",
|
|
115
|
+
"RBLNPegasusModelConfig",
|
|
116
|
+
],
|
|
111
117
|
"llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
|
|
112
118
|
"midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
|
|
113
|
-
"mistral": [
|
|
114
|
-
|
|
115
|
-
|
|
119
|
+
"mistral": [
|
|
120
|
+
"RBLNMistralForCausalLM",
|
|
121
|
+
"RBLNMistralForCausalLMConfig",
|
|
122
|
+
"RBLNMistralModel",
|
|
123
|
+
"RBLNMistralModelConfig",
|
|
124
|
+
],
|
|
125
|
+
"phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig", "RBLNPhiModel", "RBLNPhiModelConfig"],
|
|
126
|
+
"qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig", "RBLNQwen2Model", "RBLNQwen2ModelConfig"],
|
|
116
127
|
"qwen3": ["RBLNQwen3ForCausalLM", "RBLNQwen3ForCausalLMConfig", "RBLNQwen3Model", "RBLNQwen3ModelConfig"],
|
|
117
128
|
"resnet": ["RBLNResNetForImageClassification", "RBLNResNetForImageClassificationConfig"],
|
|
118
129
|
"roberta": [
|
|
@@ -215,27 +226,33 @@ if TYPE_CHECKING:
|
|
|
215
226
|
RBLNDPTForDepthEstimationConfig,
|
|
216
227
|
)
|
|
217
228
|
from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
|
|
218
|
-
from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
|
|
229
|
+
from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
|
|
219
230
|
from .gemma3 import (
|
|
220
231
|
RBLNGemma3ForCausalLM,
|
|
221
232
|
RBLNGemma3ForCausalLMConfig,
|
|
222
233
|
RBLNGemma3ForConditionalGeneration,
|
|
223
234
|
RBLNGemma3ForConditionalGenerationConfig,
|
|
224
235
|
)
|
|
225
|
-
from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
|
|
236
|
+
from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
|
|
226
237
|
from .idefics3 import (
|
|
227
238
|
RBLNIdefics3ForConditionalGeneration,
|
|
228
239
|
RBLNIdefics3ForConditionalGenerationConfig,
|
|
229
240
|
RBLNIdefics3VisionTransformer,
|
|
230
241
|
RBLNIdefics3VisionTransformerConfig,
|
|
231
242
|
)
|
|
232
|
-
from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
|
|
243
|
+
from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig, RBLNLlamaModel, RBLNLlamaModelConfig
|
|
233
244
|
from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
|
|
234
245
|
from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
|
|
235
|
-
from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
|
|
236
|
-
from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig
|
|
237
|
-
from .
|
|
238
|
-
|
|
246
|
+
from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
|
|
247
|
+
from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig, RBLNOPTModel, RBLNOPTModelConfig
|
|
248
|
+
from .pegasus import (
|
|
249
|
+
RBLNPegasusForConditionalGeneration,
|
|
250
|
+
RBLNPegasusForConditionalGenerationConfig,
|
|
251
|
+
RBLNPegasusModel,
|
|
252
|
+
RBLNPegasusModelConfig,
|
|
253
|
+
)
|
|
254
|
+
from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig, RBLNPhiModel, RBLNPhiModelConfig
|
|
255
|
+
from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig, RBLNQwen2Model, RBLNQwen2ModelConfig
|
|
239
256
|
from .qwen2_5_vl import (
|
|
240
257
|
RBLNQwen2_5_VisionTransformerPretrainedModel,
|
|
241
258
|
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
|
|
@@ -22,5 +22,5 @@ from ....ops import (
|
|
|
22
22
|
paged_flash_causal_attn_decode,
|
|
23
23
|
paged_flash_causal_attn_prefill,
|
|
24
24
|
)
|
|
25
|
-
from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
26
|
-
from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
25
|
+
from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
26
|
+
from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
|