optimum-rbln 0.9.2a2__py3-none-any.whl → 0.9.2a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +4 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +3 -0
- optimum/rbln/transformers/__init__.py +4 -0
- optimum/rbln/transformers/models/__init__.py +4 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +60 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +7 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
- {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/METADATA +1 -1
- {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/RECORD +26 -24
- {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/licenses/LICENSE +0 -0
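The headline change in this release is static multi-LoRA support for decoder-only models: a new RBLNLoRAConfig/LoRALinear pair, a `lora_int_ids` input threaded through the compiled graphs and runtimes, and `set_adapter()` / `set_lora_int_ids()` helpers on RBLNDecoderOnlyModelForCausalLM. Adapters and their integer ids are fixed when the model is compiled; at inference time you only select among them. Below is a rough usage sketch based solely on the methods visible in this diff; the concrete class, paths, and adapter names ("summarize", "translate") are assumptions, not something the diff specifies.

import torch
from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM  # concrete model class is an assumption

# Load a model that was already compiled with a LoRA config containing two adapters,
# e.g. "summarize" -> id 0 and "translate" -> id 1 (hypothetical names/ids).
model = RBLNLlamaForCausalLM.from_pretrained("path/to/compiled-model")
tokenizer = AutoTokenizer.from_pretrained("path/to/compiled-model")

inputs = tokenizer(["Summarize: ...", "Translate: ..."], return_tensors="pt", padding=True)

# Select adapters by name; the ids are pushed down to the prefill and decode runtimes.
model.set_adapter(["summarize", "translate"])
# Equivalently, pass the compile-time integer ids directly (one per sequence in the batch).
model.set_lora_int_ids([0, 1])

outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))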
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -22,6 +22,8 @@ from transformers import PretrainedConfig, PreTrainedModel
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig
+from .lora_architecture import LoRALinear


 if TYPE_CHECKING:

@@ -52,12 +54,7 @@ class DecoderOnlyWrapper(nn.Module):

     _use_learned_pos_emb = False

-    def __init__(
-        self,
-        model: PreTrainedModel,
-        rbln_config: "RBLNDecoderOnlyModelConfig",
-        use_rotary_emb: bool,
-    ):
+    def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
         super().__init__()
         self.quantization = rbln_config.quantization
         self.config = model.config

@@ -114,7 +111,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_self_attn = self.get_rbln_attn_class()(
                 self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
             )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
             new_layers.append(new_layer)

         new_model = self.get_rbln_model_class()(

@@ -154,6 +151,7 @@ class DecoderOnlyWrapper(nn.Module):
         )
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
         past_key_values = args

         if len(past_key_values) != 2 * self.num_hidden_layers:

@@ -185,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         )

@@ -199,6 +198,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         ) = self.prepare_forward_args(*args)

@@ -214,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )

         return logit

@@ -270,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         rotary_emb: nn.Module = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # outputs
         hidden_states = self.model(

@@ -283,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )

         if "prefill" in self.phase:

@@ -394,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):

@@ -466,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
                 cos=cos,
                 sin=sin,
                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
             )

         hidden_states = self.get_last_layernorm()(hidden_states)

@@ -497,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
        phase: Current operation phase ("prefill" or "decode")
    """

-    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
        super().__init__()
        self._original_mod = layer
        self.self_attn = self_attn
        self._phase = "prefill"
+        self.lora_config = lora_config
+
+        # Replace target Linear modules in MLP with LoRALinear if configured
+        if self.lora_config:
+            mlp = self.get_mlp()
+            for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+                if hasattr(mlp, proj_name):
+                    original_linear = getattr(mlp, proj_name)
+                    if isinstance(original_linear, nn.Linear):
+                        lora_linear = LoRALinear(
+                            original_linear=original_linear,
+                            lora_config=self.lora_config,
+                            projection_name=proj_name,
+                            layer_idx=self.self_attn.layer_idx,
+                        )
+                        setattr(mlp, proj_name, lora_linear)

    @property
    def phase(self):

@@ -518,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
    def get_post_attention_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.post_attention_layernorm

+    def get_mlp(self) -> nn.Module:
+        return self._original_mod.mlp
+
+    def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        mlp = self.get_mlp()
+        if self.lora_config and lora_int_id is not None:
+            gate = mlp.gate_proj(hidden_states, lora_int_id)
+            up = mlp.up_proj(hidden_states, lora_int_id)
+            act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+            if act_fn is None:
+                gate = torch.nn.functional.silu(gate)
+            else:
+                gate = act_fn(gate)
+            fused = gate * up
+            hidden_states = mlp.down_proj(fused, lora_int_id)
+        else:
+            hidden_states = mlp(hidden_states)
+        return hidden_states
+
    def forward(
        self,
        hidden_states: torch.Tensor,

@@ -527,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
    ):
        residual = hidden_states
        hidden_states = self.get_pre_attention_layernorm()(hidden_states)

@@ -539,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
            cos=cos,
            sin=sin,
            block_tables=block_tables,
+            lora_int_id=lora_int_id,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.get_post_attention_layernorm()(hidden_states)
-        hidden_states = self.
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
        hidden_states = residual + hidden_states

        return hidden_states

@@ -595,10 +637,23 @@ class DecoderOnlyAttention(nn.Module):
        self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
        self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
        self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+        self.lora_config = rbln_config.lora_config

        setattr(self, self.get_attention_name(), self.create_attention_op())
        self.__post_init__()

+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+        for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            original_linear = getattr(self._original_mod, proj_name)
+            lora_linear = LoRALinear(
+                original_linear=original_linear,
+                lora_config=self.lora_config,
+                projection_name=proj_name,
+                layer_idx=self.layer_idx,
+            )
+            setattr(self, proj_name, lora_linear)
+
    def get_attention_name(self):
        if self.is_sliding:
            return "sliding_window_attention"

@@ -651,23 +706,40 @@ class DecoderOnlyAttention(nn.Module):
            raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

    def __post_init__(self):
-
-
-
-
-
-
+        # Initialize LoRA weights if configured, which will replace linear layers
+        if self.lora_config:
+            self._init_lora_weights()
+        else:
+            # Use original linear layers if no LoRA
+            self.q_proj = self._original_mod.q_proj
+            self.k_proj = self._original_mod.k_proj
+            self.v_proj = self._original_mod.v_proj
+            self.o_proj = self._original_mod.o_proj
+
+    def projection(
+        self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Projects input hidden states into query, key, and value representations.

        Args:
            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+            lora_int_id: Adapter ID tensor for LoRA selection [batch_size]

        Returns:
            Tuple of (query_states, key_states, value_states)
        """
-
-
-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            query_states = self.q_proj(hidden_states, lora_int_id)
+            key_states = self.k_proj(hidden_states, lora_int_id)
+            value_states = self.v_proj(hidden_states, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
        return query_states, key_states, value_states

    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):

@@ -695,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
    ):
        batch_size, query_length, _ = hidden_states.size()

-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)

        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)

@@ -732,7 +805,14 @@ class DecoderOnlyAttention(nn.Module):
            v_scale=v_scale,
        )

-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            attn_outputs = self.o_proj(attn_output, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            attn_outputs = self.o_proj(attn_output)
+
        return attn_outputs

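For reference, `forward_mlp` keeps the usual gated-MLP dataflow of LLaMA-style models (down_proj of the activated gate multiplied by the up projection) and only routes each projection through `LoRALinear` so it can take the extra `lora_int_id` argument. A minimal sketch of that dataflow with plain `nn.Linear` layers, assuming a SiLU activation and arbitrary hidden sizes, matches the non-LoRA branch above:

import torch
from torch import nn

class GatedMLP(nn.Module):
    # Plain reference for the path forward_mlp takes when no LoRA adapter is active.
    def __init__(self, hidden_size: int = 64, intermediate_size: int = 128):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

x = torch.randn(2, 8, 64)
print(GatedMLP()(x).shape)  # torch.Size([2, 8, 64])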
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

@@ -187,6 +187,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
         )

+        self.lora_int_ids = None
+
     def inputs_embeddings_if_needed(
         self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
     ):

@@ -210,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ):
         inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
         block_tables, local_block_tables, is_external_block_tables = (

@@ -233,6 +236,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 position_embed=position_embed,
                 position_ids=position_ids,
                 local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
             )
         else:
             return self.prefill_forward(

@@ -245,6 +249,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 position_embed=position_embed,
                 token_type_ids=token_type_ids,
                 local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
             )

     def decode_forward(

@@ -257,7 +262,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            lora_int_ids = self.lora_int_ids
+
+        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
         if self.batch_size != cache_position.shape[0]:
             raise RuntimeError(
                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."

@@ -287,6 +305,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed,
             attention_mask if self.rbln_config.use_attention_mask else None,
             position_ids if self.rbln_config.use_position_ids else None,
+            lora_int_ids if self.rbln_config.use_lora else None,
         )

         return RBLNDecoderOnlyOutput(logits=logits)

@@ -369,12 +388,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            if batch_idx is not None:
+                lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+            else:
+                lora_int_ids = self.lora_int_ids.clone()
+
         (
             inputs,
             cache_position,

@@ -426,6 +458,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 query_position,
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
+                lora_int_ids if self.rbln_config.use_lora else None,
                 out=self.out_buffers,
             )
             output_logits.append(output_logit)
optimum/rbln/transformers/models/decoderonly/lora_architecture.py (new file)

@@ -0,0 +1,204 @@
+import math
+from pathlib import Path
+from typing import Optional
+
+import safetensors.torch
+import torch
+from torch import nn
+
+from ....utils import logging
+from .configuration_lora import RBLNLoRAConfig
+
+
+logger = logging.get_logger()
+
+
+class LoRALinear(nn.Module):
+    """
+    A linear layer that supports multiple LoRA adapters compiled at static time.
+
+    This class replaces the original linear layer and handles both base weights
+    and multiple LoRA adapters in a single forward pass using custom ops.
+    """
+
+    def __init__(
+        self,
+        original_linear: nn.Linear,
+        lora_config: RBLNLoRAConfig,
+        projection_name: str = "",
+        layer_idx: int = 0,
+    ):
+        """
+        Args:
+            original_linear: The original linear layer to be replaced
+            lora_config: LoRA configuration containing all adapters
+            projection_name: Name of the projection (e.g., "q_proj", "k_proj")
+            layer_idx: Layer index for loading the correct LoRA weights
+        """
+        super().__init__()
+
+        self.in_features = original_linear.in_features
+        self.out_features = original_linear.out_features
+        self.projection_name = projection_name
+        self.layer_idx = layer_idx
+        self.lora_config = lora_config
+
+        # Store original linear weights and bias directly without cloning
+        self.register_buffer("weight", original_linear.weight.data)
+        if original_linear.bias is not None:
+            self.register_buffer("bias", original_linear.bias.data)
+        else:
+            self.bias = None
+
+        # Initialize LoRA weights
+        self._init_lora_weights()
+
+    def _should_apply_lora(self) -> bool:
+        """Check if this projection should have LoRA applied."""
+        # Check if any adapter targets this projection
+        return any(self.projection_name in adapter.target_modules for adapter in self.lora_config.adapters)
+
+    def _load_adapter_weights(self, adapter_path: Path):
+        """
+        Load adapter weights from local directory.
+
+        Args:
+            adapter_path: Path to local directory containing adapter weights
+
+        Returns:
+            Dictionary containing adapter weights
+
+        Raises:
+            FileNotFoundError: If no adapter weights are found in the directory
+        """
+        if not adapter_path.is_dir():
+            raise ValueError(f"Adapter path must be a directory, got: {adapter_path}")
+
+        # Try to load weights in order of preference
+        weight_files = [
+            ("adapter_model.safetensors", lambda p: safetensors.torch.load_file(p)),
+            ("adapter_model.bin", lambda p: torch.load(p, map_location="cpu")),
+            ("pytorch_model.bin", lambda p: torch.load(p, map_location="cpu")),
+        ]
+
+        for filename, load_fn in weight_files:
+            weight_path = adapter_path / filename
+            if weight_path.exists():
+                return load_fn(weight_path)
+
+        raise FileNotFoundError(
+            f"No adapter weights found in {adapter_path}. "
+            f"Expected one of: {', '.join(filename for filename, _ in weight_files)}"
+        )
+
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by loading and stacking them."""
+
+        lora_a_weights = []
+        lora_b_weights = []
+
+        for adapter in self.lora_config.adapters:
+            if self.projection_name not in adapter.target_modules:
+                # Create zero weights for adapters that don't target this projection
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+                continue
+
+            adapter_weights = self._load_adapter_weights(adapter.local_adapter_path)
+
+            # Determine module type from projection name
+            attn_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
+            mlp_projs = {"gate_proj", "up_proj", "down_proj"}
+            if self.projection_name in attn_projs:
+                module_type = "self_attn"
+            elif self.projection_name in mlp_projs:
+                module_type = "mlp"
+            else:
+                module_type = "self_attn"
+
+            layer_key = f"base_model.model.model.layers.{self.layer_idx}.{module_type}.{self.projection_name}"
+            lora_a_key = f"{layer_key}.lora_A.weight"
+            lora_b_key = f"{layer_key}.lora_B.weight"
+
+            if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
+                # Calculate scaling factor and fold it into lora_b
+                scaling = adapter.lora_alpha / adapter.r
+                if adapter.use_rslora:
+                    scaling = scaling / math.sqrt(adapter.r)
+                scaling = scaling * adapter.scaling_factor
+
+                lora_a_weights.append(adapter_weights[lora_a_key])
+                # scaling is pre-applied to lora_b_weights
+                lora_b_weights.append(adapter_weights[lora_b_key] * scaling)
+            else:
+                logger.warning(f"No LoRA weights found for {lora_a_key} or {lora_b_key}")
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+
+        # Stack weights along adapter dimension
+        max_rank = self.lora_config.max_lora_rank
+
+        # Pad smaller ranks to max_rank
+        padded_lora_a = []
+        padded_lora_b = []
+
+        for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+            current_rank = lora_a.shape[0]
+            if current_rank < max_rank:
+                # Pad with zeros
+                padded_a = torch.zeros(max_rank, self.in_features)
+                padded_b = torch.zeros(self.out_features, max_rank)
+                padded_a[:current_rank] = lora_a
+                padded_b[:, :current_rank] = lora_b
+                padded_lora_a.append(padded_a)
+                padded_lora_b.append(padded_b)
+            else:
+                padded_lora_a.append(lora_a)
+                padded_lora_b.append(lora_b)
+
+        lora_a_transposed = [lora_a.transpose(0, 1) for lora_a in padded_lora_a]  # [in_features, rank]
+        lora_b_transposed = [lora_b.transpose(0, 1) for lora_b in padded_lora_b]  # [rank, out_features]
+
+        self.register_buffer(
+            "lora_a_weights", torch.stack(lora_a_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, in_features, rank]
+        self.register_buffer(
+            "lora_b_weights", torch.stack(lora_b_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, rank, out_features]
+
+    def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Forward pass that combines base linear transformation with LoRA.
+
+        Args:
+            x: Input tensor [batch_size, seq_len, in_features]
+            lora_int_id: Adapter ID tensor [batch_size] indicating which adapter to use
+
+        Returns:
+            Output tensor [batch_size, seq_len, out_features]
+        """
+        # Base linear transformation
+        output = torch.nn.functional.linear(x, self.weight, self.bias)
+
+        # Apply LoRA if enabled and adapter ID is provided
+        if self._should_apply_lora() and lora_int_id is not None:
+            # Gather LoRA weights for each batch item
+            # lora_int_id: [batch_size] -> use as indices to select weights
+            selected_lora_a = self.lora_a_weights[lora_int_id]  # [batch_size, in_features, rank]
+            selected_lora_b = self.lora_b_weights[lora_int_id]  # [batch_size, rank, out_features]
+
+            # Batched matrix multiplication for LoRA computation
+            # x: [batch_size, seq_len, in_features]
+            # selected_lora_a: [batch_size, in_features, rank] (already transposed)
+            # selected_lora_b: [batch_size, rank, out_features] (already transposed)
+
+            # First matmul: x @ lora_a -> [batch_size, seq_len, rank]
+            temp = torch.bmm(x, selected_lora_a)
+
+            # Second matmul: temp @ lora_b -> [batch_size, seq_len, out_features]
+            lora_delta = torch.bmm(temp, selected_lora_b)
+
+            # Add LoRA delta to base output
+            output = output + lora_delta
+
+        return output
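The core of `LoRALinear.forward` is a per-batch adapter gather followed by two batched matmuls against the stacked, pre-transposed A/B buffers. A standalone sketch of that math with random tensors (shapes follow the buffer comments above; no RBLN-specific code involved, and the scaling factor is assumed to be pre-folded into B, as the loader does):

import torch

batch, seq, d_in, d_out, rank, num_adapters = 2, 4, 16, 32, 8, 3

x = torch.randn(batch, seq, d_in)
weight = torch.randn(d_out, d_in)                 # base weight, as in nn.Linear
lora_a = torch.randn(num_adapters, d_in, rank)    # stacked, already transposed
lora_b = torch.randn(num_adapters, rank, d_out)   # scaling assumed pre-folded into B

lora_int_id = torch.tensor([0, 2])                # one adapter id per batch item

base = torch.nn.functional.linear(x, weight)      # [batch, seq, d_out]
delta = torch.bmm(torch.bmm(x, lora_a[lora_int_id]), lora_b[lora_int_id])
out = base + delta

# Equivalent per-sample computation for batch item 0 using adapter 0:
ref = torch.nn.functional.linear(x[0], weight) + (x[0] @ lora_a[0]) @ lora_b[0]
print(torch.allclose(out[0], ref, atol=1e-5))     # True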
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -375,6 +375,9 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))

+        if rbln_config.use_lora:
+            input_info.append(("lora_int_ids", [batch_size], "int32"))
+
         kvcache_dtype = rbln_config.torch_dtype
         if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
             kvcache_dtype = "float8_e4m3fn"

@@ -667,6 +670,53 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
         return is_prefill

+    def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+        if isinstance(lora_int_ids, int):
+            lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+        elif isinstance(lora_int_ids, list):
+            lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
+
+        self.lora_int_ids = lora_int_ids
+
+        self.prefill_decoder.lora_int_ids = lora_int_ids
+        if self.rbln_config.can_generate:
+            for batch_size in self.rbln_config.decoder_batch_sizes:
+                self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets the active adapter(s) for the model using adapter name(s).
+
+        Args:
+            adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                Can be a single adapter name or a list of adapter names.
+
+        Raises:
+            ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+        """
+        if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+            raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+        # Convert single adapter name to list for uniform processing
+        if isinstance(adapter_name, str):
+            adapter_names = [adapter_name]
+        else:
+            adapter_names = adapter_name
+
+        # Validate that all adapter names exist
+        available_adapters = {
+            adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+        }
+        missing_adapters = [name for name in adapter_names if name not in available_adapters]
+        if missing_adapters:
+            raise ValueError(
+                f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+            )
+
+        # Get the adapter IDs and set them
+        lora_int_ids = [available_adapters[name] for name in adapter_names]
+        self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

@@ -677,6 +727,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:

@@ -684,6 +735,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
         # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
         # A for-loop ensures synchronization with the HuggingFace generate API.
         # The decoder stage operates as usual, processing inputs in batch mode.
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            lora_int_ids = self.lora_int_ids

         # for only use forward
         if generate_idx is None:

@@ -708,6 +766,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)

@@ -727,6 +786,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                lora_int_ids=lora_int_ids,
             ).logits

         if not return_dict: