optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +4 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +48 -0
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/models/__init__.py +35 -14
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
- optimum/rbln/utils/depreacate_utils.py +16 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -80,12 +80,16 @@ _import_structure = {
     "RBLNDPTForDepthEstimationConfig",
     "RBLNExaoneForCausalLM",
     "RBLNExaoneForCausalLMConfig",
+    "RBLNGemmaModel",
+    "RBLNGemmaModelConfig",
     "RBLNGemmaForCausalLM",
     "RBLNGemmaForCausalLMConfig",
     "RBLNGemma3ForCausalLM",
     "RBLNGemma3ForCausalLMConfig",
     "RBLNGemma3ForConditionalGeneration",
     "RBLNGemma3ForConditionalGenerationConfig",
+    "RBLNGPT2Model",
+    "RBLNGPT2ModelConfig",
     "RBLNGPT2LMHeadModel",
     "RBLNGPT2LMHeadModelConfig",
     "RBLNIdefics3VisionTransformer",
@@ -94,22 +98,40 @@ _import_structure = {
     "RBLNIdefics3VisionTransformerConfig",
     "RBLNLlamaForCausalLM",
     "RBLNLlamaForCausalLMConfig",
+    "RBLNLlamaModel",
+    "RBLNLlamaModelConfig",
     "RBLNOPTForCausalLM",
     "RBLNOPTForCausalLMConfig",
+    "RBLNLlavaForConditionalGeneration",
+    "RBLNLlavaForConditionalGenerationConfig",
     "RBLNLlavaNextForConditionalGeneration",
     "RBLNLlavaNextForConditionalGenerationConfig",
     "RBLNMidmLMHeadModel",
     "RBLNMidmLMHeadModelConfig",
+    "RBLNMistralModel",
+    "RBLNMistralModelConfig",
     "RBLNMistralForCausalLM",
     "RBLNMistralForCausalLMConfig",
+    "RBLNOPTModel",
+    "RBLNOPTModelConfig",
+    "RBLNPegasusForConditionalGeneration",
+    "RBLNPegasusForConditionalGenerationConfig",
+    "RBLNPegasusModel",
+    "RBLNPegasusModelConfig",
     "RBLNPhiForCausalLM",
     "RBLNPhiForCausalLMConfig",
+    "RBLNPixtralVisionModel",
+    "RBLNPixtralVisionModelConfig",
+    "RBLNPhiModel",
+    "RBLNPhiModelConfig",
     "RBLNQwen2ForCausalLM",
     "RBLNQwen2ForCausalLMConfig",
     "RBLNQwen2_5_VisionTransformerPretrainedModel",
     "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
     "RBLNQwen2_5_VLForConditionalGeneration",
     "RBLNQwen2_5_VLForConditionalGenerationConfig",
+    "RBLNQwen2Model",
+    "RBLNQwen2ModelConfig",
     "RBLNQwen3ForCausalLM",
     "RBLNQwen3ForCausalLMConfig",
     "RBLNQwen3Model",
@@ -337,30 +359,52 @@ if TYPE_CHECKING:
     RBLNGemma3ForConditionalGenerationConfig,
     RBLNGemmaForCausalLM,
     RBLNGemmaForCausalLMConfig,
+    RBLNGemmaModel,
+    RBLNGemmaModelConfig,
     RBLNGPT2LMHeadModel,
     RBLNGPT2LMHeadModelConfig,
+    RBLNGPT2Model,
+    RBLNGPT2ModelConfig,
     RBLNIdefics3ForConditionalGeneration,
     RBLNIdefics3ForConditionalGenerationConfig,
     RBLNIdefics3VisionTransformer,
     RBLNIdefics3VisionTransformerConfig,
     RBLNLlamaForCausalLM,
     RBLNLlamaForCausalLMConfig,
+    RBLNLlamaModel,
+    RBLNLlamaModelConfig,
+    RBLNLlavaForConditionalGeneration,
+    RBLNLlavaForConditionalGenerationConfig,
     RBLNLlavaNextForConditionalGeneration,
     RBLNLlavaNextForConditionalGenerationConfig,
     RBLNMidmLMHeadModel,
     RBLNMidmLMHeadModelConfig,
     RBLNMistralForCausalLM,
     RBLNMistralForCausalLMConfig,
+    RBLNMistralModel,
+    RBLNMistralModelConfig,
     RBLNOPTForCausalLM,
     RBLNOPTForCausalLMConfig,
+    RBLNOPTModel,
+    RBLNOPTModelConfig,
+    RBLNPegasusForConditionalGeneration,
+    RBLNPegasusForConditionalGenerationConfig,
+    RBLNPegasusModel,
+    RBLNPegasusModelConfig,
     RBLNPhiForCausalLM,
     RBLNPhiForCausalLMConfig,
+    RBLNPhiModel,
+    RBLNPhiModelConfig,
+    RBLNPixtralVisionModel,
+    RBLNPixtralVisionModelConfig,
     RBLNQwen2_5_VisionTransformerPretrainedModel,
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGeneration,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
     RBLNQwen2ForCausalLM,
     RBLNQwen2ForCausalLMConfig,
+    RBLNQwen2Model,
+    RBLNQwen2ModelConfig,
     RBLNQwen3ForCausalLM,
     RBLNQwen3ForCausalLMConfig,
     RBLNQwen3Model,
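Each new class shows up twice in this diff because the top-level __init__.py appears to follow the transformers-style lazy-import layout: names listed in _import_structure are resolved only on first attribute access, while the TYPE_CHECKING branch repeats them for static type checkers. A minimal sketch of that pattern, assuming the stock transformers.utils._LazyModule helper (the "models" key and class names below are illustrative placeholders, not the package's actual layout):

# Sketch of a lazy __init__.py, assuming the transformers-style _LazyModule helper.
from typing import TYPE_CHECKING

_import_structure = {"models": ["RBLNGemmaModel", "RBLNGemmaModelConfig"]}

if TYPE_CHECKING:
    # Static analyzers see real imports; nothing executes here at runtime.
    from .models import RBLNGemmaModel, RBLNGemmaModelConfig
else:
    import sys

    from transformers.utils import _LazyModule

    # Attribute access (e.g. optimum.rbln.RBLNGemmaModel) triggers the actual
    # submodule import the first time the name is needed.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)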
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.8.2a4'
-__version_tuple__ = version_tuple = (0, 8, 2, 'a4')
+__version__ = version = '0.8.2a6'
+__version_tuple__ = version_tuple = (0, 8, 2, 'a6')
optimum/rbln/configuration_utils.py
CHANGED
@@ -23,6 +23,7 @@ import numpy as np
 import torch
 
 from .__version__ import __version__
+from .utils.depreacate_utils import warn_deprecated_npu
 from .utils.logging import get_logger
 from .utils.runtime_utils import ContextRblnConfig
 
@@ -675,6 +676,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
         compile_cfg.npu = self.npu
         compile_cfg.tensor_parallel_size = self.tensor_parallel_size
 
+        target_npu = self.npu or next((cfg.npu for cfg in self._compile_cfgs if cfg.npu is not None), None)
+        warn_deprecated_npu(target_npu)
+
     def freeze(self):
         if self._frozen:
             raise RuntimeError(f"`{self.__class__.__name__}` is already frozen.")
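The new deprecation hook resolves which NPU target to warn about: the config's own npu if set, otherwise the first compile config that specifies one. A minimal, self-contained sketch of that fallback idiom (the dataclass and device strings below are hypothetical stand-ins, not the library's types):

# Illustrative fallback: prefer the top-level `npu`, else the first compile
# config that sets one, else None. Names here are placeholders.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CompileCfg:
    npu: Optional[str] = None


def resolve_target_npu(npu: Optional[str], compile_cfgs: List[CompileCfg]) -> Optional[str]:
    return npu or next((cfg.npu for cfg in compile_cfgs if cfg.npu is not None), None)


assert resolve_target_npu(None, [CompileCfg(), CompileCfg(npu="RBLN-CA12")]) == "RBLN-CA12"
assert resolve_target_npu("RBLN-CA25", [CompileCfg(npu="RBLN-CA12")]) == "RBLN-CA25"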
optimum/rbln/ops/kv_cache_update.py
CHANGED
@@ -22,3 +22,8 @@ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
     # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
     # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
     return torch.empty_like(cache)
+
+
+@rbln_cache_update.register_fake
+def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
+    return torch.empty_like(cache)
optimum/rbln/ops/linear.py
CHANGED
@@ -23,3 +23,10 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
     output_shape = list(input.shape[:-1])
     output_shape += [weight.shape[0]]
     return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
+
+
+@linear.register_fake
+def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+    output_shape = list(input.shape[:-1])
+    output_shape += [weight.shape[0]]
+    return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
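Both ops files gain a fake ("meta") implementation. These let torch.compile / torch.export reason about the custom RBLN ops from shape and dtype information alone, without touching device memory. A minimal sketch of the pattern, assuming the ops are declared with torch.library.custom_op (which is what exposes register_fake); the namespace and the eager fallback below are illustrative, not the library's actual op definitions:

from typing import Optional

import torch
from torch import Tensor


# Hypothetical namespace for illustration; the real ops live in optimum.rbln.ops.
@torch.library.custom_op("rbln_demo::linear", mutates_args=())
def demo_linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    # Eager fallback so the op is runnable off-device.
    return torch.nn.functional.linear(input, weight, bias)


@demo_linear.register_fake
def _(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    # Only shapes/dtypes matter here: the last dim becomes weight.shape[0].
    output_shape = list(input.shape[:-1]) + [weight.shape[0]]
    return torch.empty(output_shape, dtype=input.dtype, device=input.device)


out = demo_linear(torch.randn(2, 8), torch.randn(4, 8))
print(out.shape)  # torch.Size([2, 4])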
optimum/rbln/transformers/__init__.py
CHANGED
@@ -68,6 +68,8 @@ _import_structure = {
     "RBLNDPTForDepthEstimationConfig",
     "RBLNExaoneForCausalLM",
     "RBLNExaoneForCausalLMConfig",
+    "RBLNGemmaModel",
+    "RBLNGemmaModelConfig",
     "RBLNGemma3ForCausalLM",
     "RBLNGemma3ForCausalLMConfig",
     "RBLNGemma3ForConditionalGeneration",
@@ -76,26 +78,48 @@ _import_structure = {
     "RBLNGemmaForCausalLMConfig",
     "RBLNGPT2LMHeadModel",
     "RBLNGPT2LMHeadModelConfig",
+    "RBLNGPT2Model",
+    "RBLNGPT2ModelConfig",
     "RBLNIdefics3ForConditionalGeneration",
     "RBLNIdefics3ForConditionalGenerationConfig",
     "RBLNIdefics3VisionTransformer",
     "RBLNIdefics3VisionTransformerConfig",
     "RBLNLlamaForCausalLM",
     "RBLNLlamaForCausalLMConfig",
+    "RBLNLlavaForConditionalGeneration",
+    "RBLNLlavaForConditionalGenerationConfig",
+    "RBLNLlamaModel",
+    "RBLNLlamaModelConfig",
+    "RBLNOPTForCausalLM",
+    "RBLNOPTForCausalLMConfig",
+    "RBLNPegasusForConditionalGeneration",
+    "RBLNPegasusForConditionalGenerationConfig",
+    "RBLNPegasusModel",
+    "RBLNPegasusModelConfig",
     "RBLNLlavaNextForConditionalGeneration",
     "RBLNLlavaNextForConditionalGenerationConfig",
     "RBLNMidmLMHeadModel",
     "RBLNMidmLMHeadModelConfig",
     "RBLNMistralForCausalLM",
     "RBLNMistralForCausalLMConfig",
+    "RBLNMistralModel",
+    "RBLNMistralModelConfig",
     "RBLNOPTForCausalLM",
     "RBLNOPTForCausalLMConfig",
+    "RBLNOPTModel",
+    "RBLNOPTModelConfig",
     "RBLNPhiForCausalLM",
     "RBLNPhiForCausalLMConfig",
+    "RBLNPixtralVisionModelConfig",
+    "RBLNPixtralVisionModel",
+    "RBLNPhiModel",
+    "RBLNPhiModelConfig",
     "RBLNQwen2_5_VisionTransformerPretrainedModel",
     "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
     "RBLNQwen2_5_VLForConditionalGeneration",
     "RBLNQwen2_5_VLForConditionalGenerationConfig",
+    "RBLNQwen2Model",
+    "RBLNQwen2ModelConfig",
     "RBLNQwen2ForCausalLM",
     "RBLNQwen2ForCausalLMConfig",
     "RBLNQwen3ForCausalLM",
@@ -170,6 +194,8 @@ if TYPE_CHECKING:
     RBLNCLIPVisionModelConfig,
     RBLNCLIPVisionModelWithProjection,
     RBLNCLIPVisionModelWithProjectionConfig,
+    RBLNColPaliForRetrieval,
+    RBLNColPaliForRetrievalConfig,
     RBLNDecoderOnlyModelForCausalLM,
     RBLNDecoderOnlyModelForCausalLMConfig,
     RBLNDistilBertForQuestionAnswering,
@@ -184,30 +210,52 @@ if TYPE_CHECKING:
     RBLNGemma3ForConditionalGenerationConfig,
     RBLNGemmaForCausalLM,
     RBLNGemmaForCausalLMConfig,
+    RBLNGemmaModel,
+    RBLNGemmaModelConfig,
     RBLNGPT2LMHeadModel,
     RBLNGPT2LMHeadModelConfig,
+    RBLNGPT2Model,
+    RBLNGPT2ModelConfig,
     RBLNIdefics3ForConditionalGeneration,
     RBLNIdefics3ForConditionalGenerationConfig,
     RBLNIdefics3VisionTransformer,
     RBLNIdefics3VisionTransformerConfig,
     RBLNLlamaForCausalLM,
     RBLNLlamaForCausalLMConfig,
+    RBLNLlamaModel,
+    RBLNLlamaModelConfig,
+    RBLNLlavaForConditionalGeneration,
+    RBLNLlavaForConditionalGenerationConfig,
     RBLNLlavaNextForConditionalGeneration,
     RBLNLlavaNextForConditionalGenerationConfig,
     RBLNMidmLMHeadModel,
     RBLNMidmLMHeadModelConfig,
     RBLNMistralForCausalLM,
     RBLNMistralForCausalLMConfig,
+    RBLNMistralModel,
+    RBLNMistralModelConfig,
     RBLNOPTForCausalLM,
     RBLNOPTForCausalLMConfig,
+    RBLNOPTModel,
+    RBLNOPTModelConfig,
+    RBLNPegasusForConditionalGeneration,
+    RBLNPegasusForConditionalGenerationConfig,
+    RBLNPegasusModel,
+    RBLNPegasusModelConfig,
     RBLNPhiForCausalLM,
     RBLNPhiForCausalLMConfig,
+    RBLNPhiModel,
+    RBLNPhiModelConfig,
+    RBLNPixtralVisionModel,
+    RBLNPixtralVisionModelConfig,
     RBLNQwen2_5_VisionTransformerPretrainedModel,
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGeneration,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
     RBLNQwen2ForCausalLM,
     RBLNQwen2ForCausalLMConfig,
+    RBLNQwen2Model,
+    RBLNQwen2ModelConfig,
     RBLNQwen3ForCausalLM,
     RBLNQwen3ForCausalLMConfig,
     RBLNQwen3Model,
optimum/rbln/transformers/modeling_attention_utils.py
ADDED
@@ -0,0 +1,252 @@
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
    RBLNDecoderOnlyModelForCausalLMConfig,
)

from ..utils.logging import get_logger


logger = get_logger()

if TYPE_CHECKING:
    from rebel import RBLNCompiledModel
    from transformers import PretrainedConfig


DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
MAX_SLIDING_WINDOW_SIZE = 32_768


def set_default_values(
    attn_impl: Optional[str] = None,
    kvcache_partition_len: Optional[int] = None,
    kvcache_block_size: Optional[int] = None,
    max_seq_len: Optional[int] = None,
) -> Tuple[str, int, int]:
    if attn_impl is None:
        attn_impl = "eager"

    if kvcache_partition_len is not None:
        if attn_impl == "eager":
            attn_impl = "flash_attn"
            logger.warning(
                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
                "`attn_impl` has been automatically switched to 'flash_attn'."
            )

    if kvcache_partition_len is None and attn_impl == "flash_attn":
        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH

    if kvcache_block_size is None:
        if attn_impl == "eager":
            kvcache_block_size = max_seq_len
        else:
            kvcache_block_size = kvcache_partition_len

    return attn_impl, kvcache_partition_len, kvcache_block_size


def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
    if attn_impl not in ["eager", "flash_attn"]:
        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")

    ## Checking Constraints...
    # Constraint of eager attention:
    # - `max_seq_len` <= 32k

    # Constraints of flash attention:
    # 1. `max_seq_len` should be multiple of `partition_len`.
    # 2. 4k <= `partition_len` <= 32k.
    # 3. `max_seq_len` should be larger then 8k.
    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
        raise ValueError(
            f"`max_seq_len` is set to {max_seq_len}, "
            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
        )

    if attn_impl == "flash_attn":
        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
            raise ValueError(
                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
            )
        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
            raise ValueError(
                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
                f"Please provide a valid value within this range."
            )
        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
            raise ValueError(
                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
            )

    if kvcache_block_size is not None:
        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
            raise ValueError(
                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
            )
        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
            raise ValueError(
                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
                f"must always be set equal to the `max_seq_len` {max_seq_len}."
            )


def validate_sliding_window(rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
    if rbln_config.sliding_window > MAX_SLIDING_WINDOW_SIZE - rbln_config.prefill_chunk_size:
        raise ValueError(
            f"Sliding window size ({rbln_config.sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - rbln_config.prefill_chunk_size})"
        )

    if rbln_config.cache_impl == "sliding_window" and rbln_config.use_attention_mask:
        raise ValueError("`use_attention_mask` must be set to False when `cache_impl` is set to 'sliding_window'.")


class RBLNDecoderOnlyFlashAttentionMixin:
    @classmethod
    def get_maximum_num_blocks(
        cls,
        config: "PretrainedConfig",
        tensor_parallel_size: int,
        kvcache_block_size: int,
        nbits_per_param: Optional[int] = None,
        n_model_params: Optional[int] = None,
        kernel_size: Optional[int] = None,
        buffer: Optional[int] = None,
        num_runtimes: int = 2,
    ) -> int:
        # We are finding max_n_blocks(x) that satisfies the following equation:

        # available_dram - kernel_size - buffer
        #     - num_layers * 2 * tensor_parallel_size
        #     * align_2MB(
        #         x
        #         * block_size
        #         * align_64(head_dim)
        #         * math.ceil(num_key_value_heads / tensor_parallel_size)
        #         * 2
        #     ) > 0

        # This inequality can be rewritten as follows:

        # a - c * align_2MB(b * x) > 0
        # where
        #    a = available_dram - kernel_size - buffer
        #    b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
        #    c = num_layers * 2 * tensor_parallel_size

        # We can rewrite the inequality as follows:
        # k > align_2MB(b*x)
        # where
        #    k = a / c

        # After that, we can derive the following equation:
        # x = floor(2**21 / b * floor((k - 1) / 2**21))

        def align(x: int, nbytes: int) -> int:
            return int(math.ceil(x / nbytes) * nbytes)

        def align_2MB(x: int) -> int:
            return align(x, 2**21)

        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
        vocab_size = config.vocab_size
        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads

        # TODO(jongho): Update if target npu is REBEL.
        ATOM_DRAM_NBYTES = 16 * 2**30
        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)

        if kernel_size is None:
            if n_model_params is None:
                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
            # Get estimated kernel size (approximated)
            lm_heads_params = align(vocab_size, 64) * hidden_size
            lm_heads_nbytes = (
                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
            )
            params = n_model_params - lm_heads_params
            layer_nbytes = (
                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
                * num_layers
                * tensor_parallel_size
            )
            kernel_size = layer_nbytes + lm_heads_nbytes
        elif n_model_params is not None:
            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")

        available_dram -= kernel_size

        if buffer is None:
            # TODO: Accurate buffer estimation
            buffer_per_runtime_per_core = 2**28  # 256MB per runtime
            buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
            buffer = buffer_per_core * tensor_parallel_size
        available_dram -= buffer

        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
        c = num_layers * 2 * tensor_parallel_size
        k = available_dram / c
        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))

        return max_n_blocks

    @classmethod
    def maybe_suggest_kvcache_num_blocks(
        cls,
        compiled_models: Dict[str, "RBLNCompiledModel"],
        model_config: "PretrainedConfig",
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    ) -> None:
        # Get the actual memory allocation of each node by key
        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
        alloc_memory_by_key: Dict[str, int] = {
            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
        }
        for batch_size in rbln_config.decoder_batch_sizes:
            for key, memory_per_node in (
                compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
            ):
                alloc_memory_by_key[key] += sum(memory_per_node)
        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight

        # Get the maximum number of blocks that can be allocated
        buffer = sum(alloc_memory_by_key.values())
        max_num_blocks = cls.get_maximum_num_blocks(
            config=model_config,
            tensor_parallel_size=rbln_config.tensor_parallel_size,
            kvcache_block_size=rbln_config.kvcache_block_size,
            kernel_size=kernel_size,
            buffer=buffer,
        )

        # Since our estimation logic is not always accurate,
        # users can set `kvcache_num_blocks` to `max_num_blocks`.
        # If the memory is not enough, the model will fail to compile.
        if rbln_config.kvcache_num_blocks < max_num_blocks:
            logger.warning(
                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
                "Our analysis indicates that additional memory is available for more blocks. "
                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
                "Please be advised that our memory estimation algorithm has limitations, "
                "and increasing this value may not guarantee successful model compilation."
            )
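The comment block in get_maximum_num_blocks derives the bound x = floor(2**21 / b * floor((k - 1) / 2**21)) from the per-device DRAM budget. A worked instance of that arithmetic, reusing the ATOM constants from the file above with otherwise hypothetical numbers (7B-class model, fp16 weights, tensor parallelism of 4):

import math

def align(x: int, nbytes: int) -> int:
    return math.ceil(x / nbytes) * nbytes

# DRAM constants taken from the file above; everything else is a hypothetical example.
ATOM_DRAM_NBYTES = 16 * 2**30       # 16 GiB of device DRAM
ATOM_SYS_DRAM_NBYTES = 288 * 2**20  # reserved for the system

tp = 4                               # tensor_parallel_size
num_layers, head_dim, num_kv_heads = 32, 128, 32
block_size = 16_384                  # kvcache_block_size (== flash-attn partition length)
kernel_size = 14 * 10**9             # roughly 14 GB of fp16 weights for a 7B model
buffer = 2**28 * 2 * tp              # 256 MB per runtime, prefill + decode, per core

a = tp * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES) - kernel_size - buffer
b = block_size * align(head_dim, 64) * math.ceil(num_kv_heads / tp) * 2  # presumably 2 bytes/elem (fp16)
c = num_layers * 2 * tp                                                  # presumably K and V, per layer, per core
k = a / c
max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
print(max_n_blocks)  # 5 -> five 16K-token KV blocks fit under these assumptions

Each block here is one 16K-token KV slice across every layer (b * c is roughly 8.6 GB), so once the weights are resident the per-block cost dominates the remaining budget.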
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -92,27 +92,40 @@ _import_structure = {
         "RBLNDPTForDepthEstimationConfig",
     ],
     "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
-    "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+    "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig", "RBLNGemmaModel", "RBLNGemmaModelConfig"],
     "gemma3": [
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
         "RBLNGemma3ForConditionalGenerationConfig",
     ],
-    "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
+    "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig", "RBLNGPT2Model", "RBLNGPT2ModelConfig"],
     "idefics3": [
         "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
         "RBLNIdefics3VisionTransformerConfig",
     ],
-    "
-    "
+    "llava": ["RBLNLlavaForConditionalGeneration", "RBLNLlavaForConditionalGenerationConfig"],
+    "llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig", "RBLNLlamaModel", "RBLNLlamaModelConfig"],
+    "opt": ["RBLNOPTForCausalLM", "RBLNOPTForCausalLMConfig", "RBLNOPTModel", "RBLNOPTModelConfig"],
+    "pegasus": [
+        "RBLNPegasusForConditionalGeneration",
+        "RBLNPegasusModel",
+        "RBLNPegasusForConditionalGenerationConfig",
+        "RBLNPegasusModelConfig",
+    ],
     "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
     "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
-    "
-    "
-
+    "pixtral": ["RBLNPixtralVisionModel", "RBLNPixtralVisionModelConfig"],
+    "mistral": [
+        "RBLNMistralForCausalLM",
+        "RBLNMistralForCausalLMConfig",
+        "RBLNMistralModel",
+        "RBLNMistralModelConfig",
+    ],
+    "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig", "RBLNPhiModel", "RBLNPhiModelConfig"],
+    "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig", "RBLNQwen2Model", "RBLNQwen2ModelConfig"],
     "qwen3": ["RBLNQwen3ForCausalLM", "RBLNQwen3ForCausalLMConfig", "RBLNQwen3Model", "RBLNQwen3ModelConfig"],
     "resnet": ["RBLNResNetForImageClassification", "RBLNResNetForImageClassificationConfig"],
     "roberta": [
@@ -215,27 +228,35 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
     )
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
-    from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+    from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
     from .gemma3 import (
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
         RBLNGemma3ForConditionalGeneration,
         RBLNGemma3ForConditionalGenerationConfig,
     )
-    from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
+    from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
    from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
         RBLNIdefics3VisionTransformer,
         RBLNIdefics3VisionTransformerConfig,
     )
-    from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+    from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig, RBLNLlamaModel, RBLNLlamaModelConfig
+    from .llava import RBLNLlavaForConditionalGeneration, RBLNLlavaForConditionalGenerationConfig
     from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
     from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
-    from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
-    from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig
-    from .
-
+    from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
+    from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig, RBLNOPTModel, RBLNOPTModelConfig
+    from .pegasus import (
+        RBLNPegasusForConditionalGeneration,
+        RBLNPegasusForConditionalGenerationConfig,
+        RBLNPegasusModel,
+        RBLNPegasusModelConfig,
+    )
+    from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig, RBLNPhiModel, RBLNPhiModelConfig
+    from .pixtral import RBLNPixtralVisionModel, RBLNPixtralVisionModelConfig
+    from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig, RBLNQwen2Model, RBLNQwen2ModelConfig
     from .qwen2_5_vl import (
         RBLNQwen2_5_VisionTransformerPretrainedModel,
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
optimum/rbln/transformers/models/decoderonly/__init__.py
CHANGED
@@ -22,5 +22,5 @@ from ....ops import (
     paged_flash_causal_attn_decode,
     paged_flash_causal_attn_prefill,
 )
-from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
-from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM