optimum-rbln 0.9.2a2__py3-none-any.whl → 0.9.2a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (26)
  1. optimum/rbln/__init__.py +4 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +3 -0
  4. optimum/rbln/transformers/__init__.py +4 -0
  5. optimum/rbln/transformers/models/__init__.py +4 -0
  6. optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
  7. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  8. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  9. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  10. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
  12. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  13. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +60 -0
  14. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  15. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  16. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +7 -0
  17. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  18. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  19. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  20. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
  21. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
  22. optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
  23. {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/METADATA +1 -1
  24. {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/RECORD +26 -24
  25. {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/WHEEL +0 -0
  26. {optimum_rbln-0.9.2a2.dist-info → optimum_rbln-0.9.2a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -22,6 +22,8 @@ from transformers import PretrainedConfig, PreTrainedModel
  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  from ...utils.rbln_quantization import RBLNQuantizationConfig
+ from .configuration_lora import RBLNLoRAConfig
+ from .lora_architecture import LoRALinear


  if TYPE_CHECKING:
@@ -52,12 +54,7 @@ class DecoderOnlyWrapper(nn.Module):

      _use_learned_pos_emb = False

-     def __init__(
-         self,
-         model: PreTrainedModel,
-         rbln_config: "RBLNDecoderOnlyModelConfig",
-         use_rotary_emb: bool,
-     ):
+     def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
          super().__init__()
          self.quantization = rbln_config.quantization
          self.config = model.config
@@ -114,7 +111,7 @@ class DecoderOnlyWrapper(nn.Module):
              new_self_attn = self.get_rbln_attn_class()(
                  self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
              )
-             new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+             new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
              new_layers.append(new_layer)

          new_model = self.get_rbln_model_class()(
@@ -154,6 +151,7 @@ class DecoderOnlyWrapper(nn.Module):
          )
          attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
          position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+         lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
          past_key_values = args

          if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -185,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
              query_position,
              attention_mask,
              position_ids,
+             lora_int_id,
              past_key_values,
              rotary_emb,
          )
@@ -199,6 +198,7 @@ class DecoderOnlyWrapper(nn.Module):
              query_position,
              attention_mask,
              position_ids,
+             lora_int_id,
              past_key_values,
              rotary_emb,
          ) = self.prepare_forward_args(*args)
@@ -214,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
              rotary_emb=rotary_emb,
              global_block_tables=global_block_tables,
              local_block_tables=local_block_tables,
+             lora_int_id=lora_int_id,
          )

          return logit
@@ -270,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
          rotary_emb: nn.Module = None,
          global_block_tables: Optional[torch.Tensor] = None,
          local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
      ):
          # outputs
          hidden_states = self.model(
@@ -283,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
              rotary_emb=rotary_emb,
              global_block_tables=global_block_tables,
              local_block_tables=local_block_tables,
+             lora_int_id=lora_int_id,
          )

          if "prefill" in self.phase:
@@ -394,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
          rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
          global_block_tables: Optional[torch.Tensor] = None,
          local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
      ):
          # retrieve input_ids and inputs_embeds
          if (input_ids is None) ^ (inputs_embeds is not None):
@@ -466,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
                  cos=cos,
                  sin=sin,
                  block_tables=local_block_tables if is_sliding else global_block_tables,
+                 lora_int_id=lora_int_id,
              )

          hidden_states = self.get_last_layernorm()(hidden_states)
@@ -497,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
          phase: Current operation phase ("prefill" or "decode")
      """

-     def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+     def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
          super().__init__()
          self._original_mod = layer
          self.self_attn = self_attn
          self._phase = "prefill"
+         self.lora_config = lora_config
+
+         # Replace target Linear modules in MLP with LoRALinear if configured
+         if self.lora_config:
+             mlp = self.get_mlp()
+             for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+                 if hasattr(mlp, proj_name):
+                     original_linear = getattr(mlp, proj_name)
+                     if isinstance(original_linear, nn.Linear):
+                         lora_linear = LoRALinear(
+                             original_linear=original_linear,
+                             lora_config=self.lora_config,
+                             projection_name=proj_name,
+                             layer_idx=self.self_attn.layer_idx,
+                         )
+                         setattr(mlp, proj_name, lora_linear)

      @property
      def phase(self):
@@ -518,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
      def get_post_attention_layernorm(self) -> nn.LayerNorm:
          return self._original_mod.post_attention_layernorm

+     def get_mlp(self) -> nn.Module:
+         return self._original_mod.mlp
+
+     def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+         mlp = self.get_mlp()
+         if self.lora_config and lora_int_id is not None:
+             gate = mlp.gate_proj(hidden_states, lora_int_id)
+             up = mlp.up_proj(hidden_states, lora_int_id)
+             act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+             if act_fn is None:
+                 gate = torch.nn.functional.silu(gate)
+             else:
+                 gate = act_fn(gate)
+             fused = gate * up
+             hidden_states = mlp.down_proj(fused, lora_int_id)
+         else:
+             hidden_states = mlp(hidden_states)
+         return hidden_states
+
      def forward(
          self,
          hidden_states: torch.Tensor,
@@ -527,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
          cos: Optional[torch.Tensor] = None,
          sin: Optional[torch.Tensor] = None,
          block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
      ):
          residual = hidden_states
          hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -539,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
              cos=cos,
              sin=sin,
              block_tables=block_tables,
+             lora_int_id=lora_int_id,
          )
          hidden_states = residual + hidden_states

          # Fully Connected
          residual = hidden_states
          hidden_states = self.get_post_attention_layernorm()(hidden_states)
-         hidden_states = self._original_mod.mlp(hidden_states)
+         hidden_states = self.forward_mlp(hidden_states, lora_int_id)
          hidden_states = residual + hidden_states

          return hidden_states
@@ -595,10 +637,23 @@ class DecoderOnlyAttention(nn.Module):
          self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
          self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
          self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+         self.lora_config = rbln_config.lora_config

          setattr(self, self.get_attention_name(), self.create_attention_op())
          self.__post_init__()

+     def _init_lora_weights(self):
+         """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+         for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+             original_linear = getattr(self._original_mod, proj_name)
+             lora_linear = LoRALinear(
+                 original_linear=original_linear,
+                 lora_config=self.lora_config,
+                 projection_name=proj_name,
+                 layer_idx=self.layer_idx,
+             )
+             setattr(self, proj_name, lora_linear)
+
      def get_attention_name(self):
          if self.is_sliding:
              return "sliding_window_attention"
@@ -651,23 +706,40 @@ class DecoderOnlyAttention(nn.Module):
              raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

      def __post_init__(self):
-         self.q_proj = self._original_mod.q_proj
-         self.k_proj = self._original_mod.k_proj
-         self.v_proj = self._original_mod.v_proj
-         self.o_proj = self._original_mod.o_proj
-
-     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         # Initialize LoRA weights if configured, which will replace linear layers
+         if self.lora_config:
+             self._init_lora_weights()
+         else:
+             # Use original linear layers if no LoRA
+             self.q_proj = self._original_mod.q_proj
+             self.k_proj = self._original_mod.k_proj
+             self.v_proj = self._original_mod.v_proj
+             self.o_proj = self._original_mod.o_proj
+
+     def projection(
+         self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """Projects input hidden states into query, key, and value representations.

          Args:
              hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+             lora_int_id: Adapter ID tensor for LoRA selection [batch_size]

          Returns:
              Tuple of (query_states, key_states, value_states)
          """
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
+         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+         if self.lora_config:
+             # LoRALinear handles both base projection and LoRA in one forward pass
+             query_states = self.q_proj(hidden_states, lora_int_id)
+             key_states = self.k_proj(hidden_states, lora_int_id)
+             value_states = self.v_proj(hidden_states, lora_int_id)
+         else:
+             # Standard linear projection without LoRA
+             query_states = self.q_proj(hidden_states)
+             key_states = self.k_proj(hidden_states)
+             value_states = self.v_proj(hidden_states)
+
          return query_states, key_states, value_states

      def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
@@ -695,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
          cos: Optional[torch.Tensor] = None,
          sin: Optional[torch.Tensor] = None,
          block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
      ):
          batch_size, query_length, _ = hidden_states.size()

-         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)

          query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
          key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -732,7 +805,14 @@ class DecoderOnlyAttention(nn.Module):
              v_scale=v_scale,
          )

-         attn_outputs = self.o_proj(attn_output)
+         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+         if self.lora_config:
+             # LoRALinear handles both base projection and LoRA in one forward pass
+             attn_outputs = self.o_proj(attn_output, lora_int_id)
+         else:
+             # Standard linear projection without LoRA
+             attn_outputs = self.o_proj(attn_output)
+
          return attn_outputs


optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

@@ -187,6 +187,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
          )

+         self.lora_int_ids = None
+
      def inputs_embeddings_if_needed(
          self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
      ):
@@ -210,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          position_ids: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
      ):
          inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
          block_tables, local_block_tables, is_external_block_tables = (
@@ -233,6 +236,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                  position_embed=position_embed,
                  position_ids=position_ids,
                  local_block_tables=local_block_tables,
+                 lora_int_ids=lora_int_ids,
              )
          else:
              return self.prefill_forward(
@@ -245,6 +249,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                  position_embed=position_embed,
                  token_type_ids=token_type_ids,
                  local_block_tables=local_block_tables,
+                 lora_int_ids=lora_int_ids,
              )

      def decode_forward(
@@ -257,7 +262,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          position_embed: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
+         if self.rbln_config.use_lora and lora_int_ids is None:
+             if self.lora_int_ids is None:
+                 raise ValueError(
+                     "lora_int_id is required when using LoRA. "
+                     "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                 )
+
+             lora_int_ids = self.lora_int_ids
+
+         if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+             raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
          if self.batch_size != cache_position.shape[0]:
              raise RuntimeError(
                  f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
@@ -287,6 +305,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              position_embed,
              attention_mask if self.rbln_config.use_attention_mask else None,
              position_ids if self.rbln_config.use_position_ids else None,
+             lora_int_ids if self.rbln_config.use_lora else None,
          )

          return RBLNDecoderOnlyOutput(logits=logits)
@@ -369,12 +388,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          position_embed: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          """
          Performs chunked prefill for efficient KV-cache updates and memory optimization.
          Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
          and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
          """
+         if self.rbln_config.use_lora and lora_int_ids is None:
+             if self.lora_int_ids is None:
+                 raise ValueError(
+                     "lora_int_id is required when using LoRA. "
+                     "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                 )
+
+             if batch_idx is not None:
+                 lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+             else:
+                 lora_int_ids = self.lora_int_ids.clone()
+
          (
              inputs,
              cache_position,
@@ -426,6 +458,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                  query_position,
                  chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                  position_ids_chunk,
+                 lora_int_ids if self.rbln_config.use_lora else None,
                  out=self.out_buffers,
              )
              output_logits.append(output_logit)
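These runtime checks pin down the calling contract for `lora_int_ids`: an int32 tensor with one adapter id per sequence, matched against the decoder batch size in `decode_forward` and sliced to a single row per `batch_idx` during chunked prefill. A minimal shape-only illustration (no RBLN runtime involved; values are made up):

```python
import torch

batch_size = 3
# One adapter id per sequence; dtype matches the "lora_int_ids" int32 graph input.
lora_int_ids = torch.tensor([0, 2, 1], dtype=torch.int32)
assert lora_int_ids.shape[0] == batch_size  # decode_forward enforces this match

# Continuous-batching prefill handles one sequence at a time, so only that
# sequence's adapter id is forwarded (mirrors lora_int_ids[batch_idx : batch_idx + 1]).
batch_idx = 1
prefill_ids = lora_int_ids[batch_idx : batch_idx + 1].clone()
print(prefill_ids)  # tensor([2], dtype=torch.int32)
```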
optimum/rbln/transformers/models/decoderonly/lora_architecture.py (new file)

@@ -0,0 +1,204 @@
+ import math
+ from pathlib import Path
+ from typing import Optional
+
+ import safetensors.torch
+ import torch
+ from torch import nn
+
+ from ....utils import logging
+ from .configuration_lora import RBLNLoRAConfig
+
+
+ logger = logging.get_logger()
+
+
+ class LoRALinear(nn.Module):
+     """
+     A linear layer that supports multiple LoRA adapters compiled at static time.
+
+     This class replaces the original linear layer and handles both base weights
+     and multiple LoRA adapters in a single forward pass using custom ops.
+     """
+
+     def __init__(
+         self,
+         original_linear: nn.Linear,
+         lora_config: RBLNLoRAConfig,
+         projection_name: str = "",
+         layer_idx: int = 0,
+     ):
+         """
+         Args:
+             original_linear: The original linear layer to be replaced
+             lora_config: LoRA configuration containing all adapters
+             projection_name: Name of the projection (e.g., "q_proj", "k_proj")
+             layer_idx: Layer index for loading the correct LoRA weights
+         """
+         super().__init__()
+
+         self.in_features = original_linear.in_features
+         self.out_features = original_linear.out_features
+         self.projection_name = projection_name
+         self.layer_idx = layer_idx
+         self.lora_config = lora_config
+
+         # Store original linear weights and bias directly without cloning
+         self.register_buffer("weight", original_linear.weight.data)
+         if original_linear.bias is not None:
+             self.register_buffer("bias", original_linear.bias.data)
+         else:
+             self.bias = None
+
+         # Initialize LoRA weights
+         self._init_lora_weights()
+
+     def _should_apply_lora(self) -> bool:
+         """Check if this projection should have LoRA applied."""
+         # Check if any adapter targets this projection
+         return any(self.projection_name in adapter.target_modules for adapter in self.lora_config.adapters)
+
+     def _load_adapter_weights(self, adapter_path: Path):
+         """
+         Load adapter weights from local directory.
+
+         Args:
+             adapter_path: Path to local directory containing adapter weights
+
+         Returns:
+             Dictionary containing adapter weights
+
+         Raises:
+             FileNotFoundError: If no adapter weights are found in the directory
+         """
+         if not adapter_path.is_dir():
+             raise ValueError(f"Adapter path must be a directory, got: {adapter_path}")
+
+         # Try to load weights in order of preference
+         weight_files = [
+             ("adapter_model.safetensors", lambda p: safetensors.torch.load_file(p)),
+             ("adapter_model.bin", lambda p: torch.load(p, map_location="cpu")),
+             ("pytorch_model.bin", lambda p: torch.load(p, map_location="cpu")),
+         ]
+
+         for filename, load_fn in weight_files:
+             weight_path = adapter_path / filename
+             if weight_path.exists():
+                 return load_fn(weight_path)
+
+         raise FileNotFoundError(
+             f"No adapter weights found in {adapter_path}. "
+             f"Expected one of: {', '.join(filename for filename, _ in weight_files)}"
+         )
+
+     def _init_lora_weights(self):
+         """Initialize LoRA adapter weights by loading and stacking them."""
+
+         lora_a_weights = []
+         lora_b_weights = []
+
+         for adapter in self.lora_config.adapters:
+             if self.projection_name not in adapter.target_modules:
+                 # Create zero weights for adapters that don't target this projection
+                 lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                 lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+                 continue
+
+             adapter_weights = self._load_adapter_weights(adapter.local_adapter_path)
+
+             # Determine module type from projection name
+             attn_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
+             mlp_projs = {"gate_proj", "up_proj", "down_proj"}
+             if self.projection_name in attn_projs:
+                 module_type = "self_attn"
+             elif self.projection_name in mlp_projs:
+                 module_type = "mlp"
+             else:
+                 module_type = "self_attn"
+
+             layer_key = f"base_model.model.model.layers.{self.layer_idx}.{module_type}.{self.projection_name}"
+             lora_a_key = f"{layer_key}.lora_A.weight"
+             lora_b_key = f"{layer_key}.lora_B.weight"
+
+             if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
+                 # Calculate scaling factor and fold it into lora_b
+                 scaling = adapter.lora_alpha / adapter.r
+                 if adapter.use_rslora:
+                     scaling = scaling / math.sqrt(adapter.r)
+                 scaling = scaling * adapter.scaling_factor
+
+                 lora_a_weights.append(adapter_weights[lora_a_key])
+                 # scaling is pre-applied to lora_b_weights
+                 lora_b_weights.append(adapter_weights[lora_b_key] * scaling)
+             else:
+                 logger.warning(f"No LoRA weights found for {lora_a_key} or {lora_b_key}")
+                 lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                 lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+
+         # Stack weights along adapter dimension
+         max_rank = self.lora_config.max_lora_rank
+
+         # Pad smaller ranks to max_rank
+         padded_lora_a = []
+         padded_lora_b = []
+
+         for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+             current_rank = lora_a.shape[0]
+             if current_rank < max_rank:
+                 # Pad with zeros
+                 padded_a = torch.zeros(max_rank, self.in_features)
+                 padded_b = torch.zeros(self.out_features, max_rank)
+                 padded_a[:current_rank] = lora_a
+                 padded_b[:, :current_rank] = lora_b
+                 padded_lora_a.append(padded_a)
+                 padded_lora_b.append(padded_b)
+             else:
+                 padded_lora_a.append(lora_a)
+                 padded_lora_b.append(lora_b)
+
+         lora_a_transposed = [lora_a.transpose(0, 1) for lora_a in padded_lora_a]  # [in_features, rank]
+         lora_b_transposed = [lora_b.transpose(0, 1) for lora_b in padded_lora_b]  # [rank, out_features]
+
+         self.register_buffer(
+             "lora_a_weights", torch.stack(lora_a_transposed, dim=0).to(self.weight.dtype)
+         )  # [num_adapters, in_features, rank]
+         self.register_buffer(
+             "lora_b_weights", torch.stack(lora_b_transposed, dim=0).to(self.weight.dtype)
+         )  # [num_adapters, rank, out_features]
+
+     def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """
+         Forward pass that combines base linear transformation with LoRA.
+
+         Args:
+             x: Input tensor [batch_size, seq_len, in_features]
+             lora_int_id: Adapter ID tensor [batch_size] indicating which adapter to use
+
+         Returns:
+             Output tensor [batch_size, seq_len, out_features]
+         """
+         # Base linear transformation
+         output = torch.nn.functional.linear(x, self.weight, self.bias)
+
+         # Apply LoRA if enabled and adapter ID is provided
+         if self._should_apply_lora() and lora_int_id is not None:
+             # Gather LoRA weights for each batch item
+             # lora_int_id: [batch_size] -> use as indices to select weights
+             selected_lora_a = self.lora_a_weights[lora_int_id]  # [batch_size, in_features, rank]
+             selected_lora_b = self.lora_b_weights[lora_int_id]  # [batch_size, rank, out_features]
+
+             # Batched matrix multiplication for LoRA computation
+             # x: [batch_size, seq_len, in_features]
+             # selected_lora_a: [batch_size, in_features, rank] (already transposed)
+             # selected_lora_b: [batch_size, rank, out_features] (already transposed)
+
+             # First matmul: x @ lora_a -> [batch_size, seq_len, rank]
+             temp = torch.bmm(x, selected_lora_a)
+
+             # Second matmul: temp @ lora_b -> [batch_size, seq_len, out_features]
+             lora_delta = torch.bmm(temp, selected_lora_b)
+
+             # Add LoRA delta to base output
+             output = output + lora_delta
+
+         return output
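The stacked-adapter layout above can be checked with plain tensors: ranks are zero-padded to `max_lora_rank`, the per-adapter scaling (`lora_alpha / r`, additionally divided by `sqrt(r)` for rsLoRA) is folded into the B matrix, and the forward pass reduces to two batched matmuls indexed by `lora_int_id`. The snippet below is a standalone numerical sketch of that math, not the compiled RBLN op.

```python
import torch

batch, seq, in_f, out_f, max_rank = 2, 3, 8, 6, 4
x = torch.randn(batch, seq, in_f)
weight = torch.randn(out_f, in_f)

# Two adapters with different ranks; scaling folded into B as in _init_lora_weights.
ranks, alphas = [4, 2], [8.0, 4.0]
a_stack, b_stack = [], []
for r, alpha in zip(ranks, alphas):
    lora_a = torch.randn(r, in_f)                 # [r, in_features]
    lora_b = torch.randn(out_f, r) * (alpha / r)  # scaling pre-applied
    pad_a = torch.zeros(max_rank, in_f)
    pad_b = torch.zeros(out_f, max_rank)
    pad_a[:r], pad_b[:, :r] = lora_a, lora_b
    a_stack.append(pad_a.T)                       # [in_features, max_rank]
    b_stack.append(pad_b.T)                       # [max_rank, out_features]
lora_a_weights = torch.stack(a_stack)             # [num_adapters, in_features, max_rank]
lora_b_weights = torch.stack(b_stack)             # [num_adapters, max_rank, out_features]

# Forward: base linear plus the delta of the adapter selected per batch item.
lora_int_id = torch.tensor([1, 0])
base = torch.nn.functional.linear(x, weight)
delta = torch.bmm(torch.bmm(x, lora_a_weights[lora_int_id]), lora_b_weights[lora_int_id])
out = base + delta

# Same result as applying each adapter's (unpadded) A/B to its own sequence.
ref = base.clone()
for i, aid in enumerate(lora_int_id.tolist()):
    r = ranks[aid]
    a = lora_a_weights[aid, :, :r]                # [in_features, r]
    b = lora_b_weights[aid, :r, :]                # [r, out_features]
    ref[i] += x[i] @ a @ b
assert torch.allclose(out, ref, atol=1e-5)
```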
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -375,6 +375,9 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
          if rbln_config.use_position_ids:
              input_info.append(("position_ids", [batch_size, query_length], "int32"))

+         if rbln_config.use_lora:
+             input_info.append(("lora_int_ids", [batch_size], "int32"))
+
          kvcache_dtype = rbln_config.torch_dtype
          if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
              kvcache_dtype = "float8_e4m3fn"
@@ -667,6 +670,53 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
      def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
          return is_prefill

+     def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+         if isinstance(lora_int_ids, int):
+             lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+         elif isinstance(lora_int_ids, list):
+             lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
+
+         self.lora_int_ids = lora_int_ids
+
+         self.prefill_decoder.lora_int_ids = lora_int_ids
+         if self.rbln_config.can_generate:
+             for batch_size in self.rbln_config.decoder_batch_sizes:
+                 self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+     def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+         """
+         Sets the active adapter(s) for the model using adapter name(s).
+
+         Args:
+             adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                 Can be a single adapter name or a list of adapter names.
+
+         Raises:
+             ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+         """
+         if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+             raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+         # Convert single adapter name to list for uniform processing
+         if isinstance(adapter_name, str):
+             adapter_names = [adapter_name]
+         else:
+             adapter_names = adapter_name
+
+         # Validate that all adapter names exist
+         available_adapters = {
+             adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+         }
+         missing_adapters = [name for name in adapter_names if name not in available_adapters]
+         if missing_adapters:
+             raise ValueError(
+                 f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+             )
+
+         # Get the adapter IDs and set them
+         lora_int_ids = [available_adapters[name] for name in adapter_names]
+         self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
+
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
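Taken together with `set_lora_int_ids`, this gives two ways to pick adapters at generation time: by registered name or by integer id. A hypothetical usage sketch follows; the adapter names, the model class, and the compiled-model path are illustrative assumptions based on this diff, not documented API.

```python
# Hypothetical usage, assuming the model was compiled with an RBLNLoRAConfig whose
# adapters were registered under the names "math" and "chat" (names are made up here).
import torch
from optimum.rbln import RBLNLlamaForCausalLM  # any RBLNDecoderOnlyModelForCausalLM subclass

model = RBLNLlamaForCausalLM.from_pretrained("path/to/compiled-model-with-lora")  # illustrative path

# Option 1: select the active adapter by name; names resolve to lora_int_id internally.
model.set_adapter("math")

# Option 2: set integer adapter ids directly, one per decoder batch slot.
model.set_lora_int_ids(torch.tensor([0, 1], dtype=torch.int32))

# Per-request ids can also be passed straight to forward() via the new
# lora_int_ids argument instead of being set ahead of time.
```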
@@ -677,6 +727,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
          padded_cache_lengths: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
+         lora_int_ids: Optional[torch.Tensor] = None,
          return_dict: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
@@ -684,6 +735,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
          # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
          # A for-loop ensures synchronization with the HuggingFace generate API.
          # The decoder stage operates as usual, processing inputs in batch mode.
+         if self.rbln_config.use_lora and lora_int_ids is None:
+             if self.lora_int_ids is None:
+                 raise ValueError(
+                     "lora_int_id is required when using LoRA. "
+                     "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                 )
+             lora_int_ids = self.lora_int_ids

          # for only use forward
          if generate_idx is None:
@@ -708,6 +766,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                      cache_position=cache_position,
                      batch_idx=b_idx,
                      token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                     lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                  )
                  padded_cache_lengths[b_idx] += output.padded_cache_lengths
                  logits.append(output.logits)
@@ -727,6 +786,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                  inputs_embeds=inputs_embeds,
                  cache_position=cache_position,
                  position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                 lora_int_ids=lora_int_ids,
              ).logits

          if not return_dict: