optimum-rbln 0.9.2a3__py3-none-any.whl → 0.9.2a5__py3-none-any.whl
- optimum/rbln/__init__.py +4 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +3 -0
- optimum/rbln/modeling.py +71 -1
- optimum/rbln/transformers/__init__.py +4 -0
- optimum/rbln/transformers/modeling_generic.py +23 -1
- optimum/rbln/transformers/models/__init__.py +4 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +65 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +79 -4
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +9 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +2 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +0 -9
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/METADATA +5 -5
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/RECORD +34 -32
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -118,6 +118,8 @@ _import_structure = {
         "RBLNLlavaForConditionalGenerationConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
+        "RBLNLoRAAdapterConfig",
+        "RBLNLoRAConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralModel",
@@ -406,6 +408,8 @@ if TYPE_CHECKING:
        RBLNLlavaForConditionalGenerationConfig,
        RBLNLlavaNextForConditionalGeneration,
        RBLNLlavaNextForConditionalGenerationConfig,
+       RBLNLoRAAdapterConfig,
+       RBLNLoRAConfig,
        RBLNMidmLMHeadModel,
        RBLNMidmLMHeadModelConfig,
        RBLNMistralForCausalLM,
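The net effect of the export changes above is that the two new LoRA configuration classes become importable from the package root. A minimal sanity check, assuming the 0.9.2a5 wheel is installed:

from optimum.rbln import RBLNLoRAAdapterConfig, RBLNLoRAConfig

print(RBLNLoRAConfig, RBLNLoRAAdapterConfig)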
optimum/rbln/__version__.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.9.2a3'
-__version_tuple__ = version_tuple = (0, 9, 2, 'a3')
+__version__ = version = '0.9.2a5'
+__version_tuple__ = version_tuple = (0, 9, 2, 'a5')

 __commit_id__ = commit_id = None
optimum/rbln/configuration_utils.py
CHANGED
@@ -41,6 +41,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
 class RBLNSerializableConfigProtocol(Protocol):
     def _prepare_for_serialization(self) -> Dict[str, Any]: ...

+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
+

 @dataclass
 class RBLNCompileConfig:
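The added __repr__ gives every config implementing the protocol a readable printout: the class name plus its serialized field dict. A minimal sketch of the resulting behavior, using a hypothetical stand-in class rather than a real RBLN config:

from typing import Any, Dict

class DemoConfig:
    # Hypothetical config implementing RBLNSerializableConfigProtocol
    def _prepare_for_serialization(self) -> Dict[str, Any]:
        return {"batch_size": 1, "tensor_parallel_size": 4}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._prepare_for_serialization()})"

print(DemoConfig())  # DemoConfig({'batch_size': 1, 'tensor_parallel_size': 4})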
optimum/rbln/modeling.py
CHANGED
@@ -34,6 +34,49 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+def _get_dtype(
+    cls,
+    dtype: Optional[Union[str, torch.dtype, dict]],
+    config: PretrainedConfig,
+) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
+    dtype_orig = None
+
+    if dtype is not None:
+        if isinstance(dtype, str):
+            if dtype == "auto":
+                if hasattr(config, "dtype") and config.dtype is not None:
+                    dtype = config.dtype
+                else:
+                    dtype = torch.get_default_dtype()
+            elif hasattr(torch, dtype):
+                dtype = getattr(torch, dtype)
+            config.dtype = dtype
+        elif isinstance(dtype, torch.dtype):
+            config.dtype = dtype
+        elif isinstance(dtype, dict):
+            for key, curr_dtype in dtype.items():
+                if hasattr(config, key):
+                    value = getattr(config, key)
+                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
+                    value.dtype = curr_dtype
+            # main torch dtype for modules that aren't part of any sub-config
+            dtype = dtype.get("")
+            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
+            config.dtype = dtype
+            if dtype is None:
+                dtype = torch.float32
+        else:
+            raise ValueError(f"Invalid dtype: {dtype}")
+
+        dtype_orig = cls._set_default_dtype(dtype)
+    else:
+        # Use default dtype
+        default_dtype = torch.get_default_dtype()
+        config.dtype = default_dtype
+
+    return config, dtype, dtype_orig
+
+
 class RBLNModel(RBLNBaseModel):
     @classmethod
     def update_kwargs(cls, kwargs):
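The dict form of dtype follows the transformers convention: each key names a sub-config that gets its own dtype, and the empty-string key carries the main dtype for modules outside any sub-config (falling back to float32 when absent). An illustrative value, with hypothetical sub-config names:

import torch

dtype = {
    "text_config": "bfloat16",       # sets config.text_config.dtype
    "vision_config": torch.float16,  # sets config.vision_config.dtype
    "": "float32",                   # main dtype for everything else
}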
@@ -206,10 +249,37 @@
         trust_remote_code: bool = False,
         # Some rbln-config should be applied before loading torch module (i.e. quantized llm)
         rbln_config: Optional[RBLNModelConfig] = None,
+        dtype: Optional[Union[str, torch.dtype, dict]] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
-
+
+        hf_class = cls.get_hf_class()
+
+        if dtype is not None:
+            config = hf_class.config_class.from_pretrained(
+                model_id,
+                subfolder=subfolder,
+                revision=revision,
+                cache_dir=cache_dir,
+                use_auth_token=use_auth_token,
+                local_files_only=local_files_only,
+                force_download=force_download,
+                trust_remote_code=trust_remote_code,
+            )
+
+            config, processed_dtype, dtype_orig = _get_dtype(
+                cls=hf_class,
+                dtype=dtype,
+                config=config,
+            )
+
+            kwargs["torch_dtype"] = processed_dtype
+
+            if dtype_orig is not None:
+                hf_class._set_default_dtype(dtype_orig)
+
+        return hf_class.from_pretrained(
             model_id,
             subfolder=subfolder,
             revision=revision,
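In practice this means the torch module loader now accepts a dtype argument end to end: a string such as "float16", a torch.dtype, "auto" (defer to the checkpoint config), or the per-sub-config dict shown above. A hedged usage sketch; the model class and id here are placeholders for illustration, not taken from this diff:

import torch
from optimum.rbln import RBLNLlamaForCausalLM  # illustrative class choice

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    export=True,                 # compile from the PyTorch checkpoint
    dtype=torch.bfloat16,
)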
optimum/rbln/transformers/__init__.py
CHANGED
@@ -110,6 +110,8 @@ _import_structure = {
         "RBLNPegasusModelConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
+        "RBLNLoRAAdapterConfig",
+        "RBLNLoRAConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralForCausalLM",
@@ -258,6 +260,8 @@ if TYPE_CHECKING:
        RBLNLlavaForConditionalGenerationConfig,
        RBLNLlavaNextForConditionalGeneration,
        RBLNLlavaNextForConditionalGenerationConfig,
+       RBLNLoRAAdapterConfig,
+       RBLNLoRAConfig,
        RBLNMidmLMHeadModel,
        RBLNMidmLMHeadModelConfig,
        RBLNMistralForCausalLM,
optimum/rbln/transformers/modeling_generic.py
CHANGED
@@ -23,6 +23,7 @@ different model architectures.
 import inspect
 from typing import TYPE_CHECKING, Optional, Union

+from torch import nn
 from transformers import (
     AutoModel,
     AutoModelForAudioClassification,
@@ -57,6 +58,28 @@ class RBLNTransformerEncoder(RBLNModel):
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     rbln_dtype = "int64"

+    @classmethod
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
+        class TransformerEncoderWrapper(nn.Module):
+            # Parameters to disable for RBLN compilation
+            DISABLED_PARAMS = {"return_dict", "use_cache"}
+
+            def __init__(self, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig):
+                super().__init__()
+                self.model = model
+                self.rbln_config = rbln_config
+                self._forward_signature = inspect.signature(model.forward)
+
+            def forward(self, *args, **kwargs):
+                # Disable parameters that are not compatible with RBLN compilation
+                for param_name in self.DISABLED_PARAMS:
+                    if param_name in self._forward_signature.parameters:
+                        kwargs[param_name] = False
+
+                return self.model(*args, **kwargs)
+
+        return TransformerEncoderWrapper(model, rbln_config).eval()
+
     @classmethod
     def _update_rbln_config(
         cls,
@@ -208,7 +231,6 @@ class RBLNModelForQuestionAnswering(RBLNTransformerEncoder):

     def _prepare_output(self, output, return_dict):
         # Prepare QuestionAnswering specific output format.
-
         start_logits, end_logits = output

         if not return_dict:
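The wrapper's pattern is worth noting on its own: inspect the wrapped forward's signature once at construction, then force-disable incompatible keyword arguments on every call so compilation sees plain tensor outputs. A standalone sketch of the same idea with a toy module (not the optimum-rbln class itself):

import inspect

import torch
from torch import nn

class Inner(nn.Module):
    def forward(self, x, return_dict=True, use_cache=True):
        # Toy forward whose output type depends on a flag
        return {"out": x} if return_dict else x

class DisableFlagsWrapper(nn.Module):
    DISABLED_PARAMS = {"return_dict", "use_cache"}

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self._forward_signature = inspect.signature(model.forward)

    def forward(self, *args, **kwargs):
        for name in self.DISABLED_PARAMS:
            if name in self._forward_signature.parameters:
                kwargs[name] = False  # force trace-friendly tensor outputs
        return self.model(*args, **kwargs)

wrapped = DisableFlagsWrapper(Inner()).eval()
print(wrapped(torch.ones(2)))  # tensor([1., 1.]) -- return_dict forced off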
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -96,6 +96,8 @@ _import_structure = {
        "RBLNDecoderOnlyModel",
        "RBLNDecoderOnlyModelForCausalLM",
        "RBLNDecoderOnlyModelForCausalLMConfig",
+       "RBLNLoRAAdapterConfig",
+       "RBLNLoRAConfig",
    ],
    "depth_anything": ["RBLNDepthAnythingForDepthEstimationConfig", "RBLNDepthAnythingForDepthEstimation"],
    "dpt": [
@@ -239,6 +241,8 @@ if TYPE_CHECKING:
        RBLNDecoderOnlyModelConfig,
        RBLNDecoderOnlyModelForCausalLM,
        RBLNDecoderOnlyModelForCausalLMConfig,
+       RBLNLoRAAdapterConfig,
+       RBLNLoRAConfig,
    )
    from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
    from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
CHANGED
@@ -31,6 +31,7 @@ from transformers.utils import logging
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin


 logger = logging.get_logger(__name__)
@@ -265,7 +266,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
 )


-class RBLNBlip2ForConditionalGeneration(RBLNModel):
+class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -433,3 +434,66 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
         )

         return inputs_embeds
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+        batch_size = pixel_values.shape[0]
+        image_embeds = self.vision_model(
+            pixel_values,
+            return_dict=True,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        ).last_hidden_state
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_outputs = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs.last_hidden_state
+
+        if query_output.dtype != image_embeds.dtype:
+            query_output = query_output.to(image_embeds.dtype)
+
+        language_model_inputs = self.language_projection(query_output)
+
+        if inputs_embeds is None:
+            if input_ids is None:
+                image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
+                start_tokens = image_tokens + [self.config.text_config.bos_token_id]
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+        return outputs
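With the mixin in the bases and the dedicated generate above, BLIP-2 image-conditioned generation follows the familiar transformers flow. A hedged sketch; the processor usage and model id are standard transformers conventions, not taken from this diff:

from PIL import Image
from transformers import Blip2Processor
from optimum.rbln import RBLNBlip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = RBLNBlip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", export=True
)

image = Image.open("photo.jpg")  # placeholder image path
inputs = processor(images=image, text="Question: what is shown? Answer:", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output_ids, skip_special_tokens=True))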
optimum/rbln/transformers/models/decoderonly/__init__.py
CHANGED
@@ -23,4 +23,5 @@ from ....ops import (
     paged_flash_causal_attn_prefill,
 )
 from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .configuration_lora import RBLNLoRAAdapterConfig, RBLNLoRAConfig
 from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
CHANGED
@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, get_args
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
 from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig


 logger = get_logger()
@@ -48,6 +49,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         kvcache_partition_len: Optional[int] = None,
         kvcache_block_size: Optional[int] = None,
         quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
+        lora_config: Optional[Union[Dict[str, Any], RBLNLoRAConfig]] = None,
         prefill_chunk_size: Optional[int] = None,
         kvcache_num_blocks: Optional[int] = None,
         decoder_batch_sizes: Optional[List[int]] = None,
@@ -80,6 +82,12 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
                 in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
                 section below for details.
+            quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
+                quantization. Specifies format, etc.
+            lora_config (Optional[Union[Dict[str, Any], RBLNLoRAConfig]]): Configuration for LoRA
+                (Low-Rank Adaptation) settings when using (multi-)LoRA support. Can be provided as
+                a dictionary or an RBLNLoRAConfig instance. When provided, enables LoRA functionality
+                for the model compilation. Defaults to None (no LoRA).
             prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
                 processing input sequences. Defaults to 128. Must be a positive integer
                 divisible by 64. Affects prefill performance and memory usage.
@@ -185,6 +193,26 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         if self.quantization and isinstance(self.quantization, dict):
             self.quantization = RBLNQuantizationConfig(**self.quantization)

+        self.lora_config = lora_config
+        if self.lora_config and isinstance(self.lora_config, dict):
+            self.lora_config = RBLNLoRAConfig(**self.lora_config)
+
+        # Validate LoRA adapters if LoRA is enabled
+        if self.lora_config is not None:
+            validation_results = self.lora_config.validate_adapter_weights()
+            failed_adapters = [adapter_id for adapter_id, is_valid in validation_results.items() if not is_valid]
+
+            if failed_adapters:
+                raise ValueError(
+                    f"Some LoRA adapters failed validation and may not be accessible at compile time: {failed_adapters}. "
+                    "Please ensure all adapter weights are available and properly formatted."
+                )
+
+            logger.info(
+                f"LoRA configuration initialized with {self.lora_config.num_adapters} adapters: "
+                f"{self.lora_config.adapter_ids}. Max rank: {self.lora_config.max_lora_rank}"
+            )
+
         self.attn_impl = attn_impl
         self.kvcache_partition_len = kvcache_partition_len
         self.kvcache_block_size = kvcache_block_size
@@ -204,6 +232,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         if self.logits_to_keep is not None and self.logits_to_keep > 1:
             raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")

+        self.decoder_batch_sizes = None
         if "decode" in self.phases:
             self.decoder_batch_sizes = decoder_batch_sizes
             if self.decoder_batch_sizes is None:
@@ -243,6 +272,11 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
     def use_multiple_decoder(self) -> bool:
         return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1

+    @property
+    def use_lora(self):
+        """Check if LoRA is enabled for this configuration."""
+        return self.lora_config is not None
+
     @property
     def can_generate(self) -> bool:
         return "decode" in self.phases
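Taken together: lora_config arrives as a dict or an RBLNLoRAConfig, dicts are promoted in __init__, adapter weights are validated eagerly so failures surface at compile time rather than at inference, and the use_lora property gates the LoRA compile path. A hedged sketch of wiring it up; the adapter entry keys are guesses at RBLNLoRAAdapterConfig's schema, which this diff does not show:

from optimum.rbln import RBLNDecoderOnlyModelForCausalLMConfig

cfg = RBLNDecoderOnlyModelForCausalLMConfig(
    lora_config={
        # Hypothetical adapter entries; real weights must exist on disk or on
        # the Hub for validate_adapter_weights() to pass.
        "adapters": [
            {"adapter_id": "math", "path": "user/llama-lora-math"},
            {"adapter_id": "code", "path": "user/llama-lora-code"},
        ],
    },
)
print(cfg.use_lora)  # True once validation succeeds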