optimum-rbln 0.9.2a3__py3-none-any.whl → 0.9.2a5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in that public registry.

Files changed (34)
  1. optimum/rbln/__init__.py +4 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +3 -0
  4. optimum/rbln/modeling.py +71 -1
  5. optimum/rbln/transformers/__init__.py +4 -0
  6. optimum/rbln/transformers/modeling_generic.py +23 -1
  7. optimum/rbln/transformers/models/__init__.py +4 -0
  8. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +65 -1
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  11. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  13. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
  14. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  15. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +79 -4
  16. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  17. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  18. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +9 -1
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  20. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -4
  21. optimum/rbln/transformers/models/llava/modeling_llava.py +2 -1
  22. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -1
  23. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  24. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  25. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +0 -9
  26. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
  27. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
  28. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  29. optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
  30. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  31. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/METADATA +5 -5
  32. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/RECORD +34 -32
  33. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/WHEEL +0 -0
  34. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/licenses/LICENSE +0 -0
@@ -187,6 +187,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
  )

+ self.lora_int_ids = None
+
  def inputs_embeddings_if_needed(
  self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
  ):
@@ -210,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_ids: Optional[torch.Tensor] = None,
  token_type_ids: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_ids: Optional[torch.Tensor] = None,
  ):
  inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
  block_tables, local_block_tables, is_external_block_tables = (
@@ -233,6 +236,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_embed=position_embed,
  position_ids=position_ids,
  local_block_tables=local_block_tables,
+ lora_int_ids=lora_int_ids,
  )
  else:
  return self.prefill_forward(
@@ -245,6 +249,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_embed=position_embed,
  token_type_ids=token_type_ids,
  local_block_tables=local_block_tables,
+ lora_int_ids=lora_int_ids,
  )

  def decode_forward(
@@ -257,7 +262,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_embed: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_ids: Optional[torch.Tensor] = None,
  ) -> torch.FloatTensor:
+ if self.rbln_config.use_lora and lora_int_ids is None:
+ if self.lora_int_ids is None:
+ raise ValueError(
+ "lora_int_id is required when using LoRA. "
+ "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+ )
+
+ lora_int_ids = self.lora_int_ids
+
+ if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+ raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
  if self.batch_size != cache_position.shape[0]:
  raise RuntimeError(
  f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
@@ -287,6 +305,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_embed,
  attention_mask if self.rbln_config.use_attention_mask else None,
  position_ids if self.rbln_config.use_position_ids else None,
+ lora_int_ids if self.rbln_config.use_lora else None,
  )

  return RBLNDecoderOnlyOutput(logits=logits)
@@ -369,12 +388,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_embed: Optional[torch.Tensor] = None,
  token_type_ids: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_ids: Optional[torch.Tensor] = None,
  ) -> torch.FloatTensor:
  """
  Performs chunked prefill for efficient KV-cache updates and memory optimization.
  Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
  and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
  """
+ if self.rbln_config.use_lora and lora_int_ids is None:
+ if self.lora_int_ids is None:
+ raise ValueError(
+ "lora_int_id is required when using LoRA. "
+ "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+ )
+
+ if batch_idx is not None:
+ lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+ else:
+ lora_int_ids = self.lora_int_ids.clone()
+
  (
  inputs,
  cache_position,
@@ -426,6 +458,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  query_position,
  chunked_attention_mask if self.rbln_config.use_attention_mask else None,
  position_ids_chunk,
+ lora_int_ids if self.rbln_config.use_lora else None,
  out=self.out_buffers,
  )
  output_logits.append(output_logit)
@@ -0,0 +1,204 @@
+ import math
+ from pathlib import Path
+ from typing import Optional
+
+ import safetensors.torch
+ import torch
+ from torch import nn
+
+ from ....utils import logging
+ from .configuration_lora import RBLNLoRAConfig
+
+
+ logger = logging.get_logger()
+
+
+ class LoRALinear(nn.Module):
+ """
+ A linear layer that supports multiple LoRA adapters compiled at static time.
+
+ This class replaces the original linear layer and handles both base weights
+ and multiple LoRA adapters in a single forward pass using custom ops.
+ """
+
+ def __init__(
+ self,
+ original_linear: nn.Linear,
+ lora_config: RBLNLoRAConfig,
+ projection_name: str = "",
+ layer_idx: int = 0,
+ ):
+ """
+ Args:
+ original_linear: The original linear layer to be replaced
+ lora_config: LoRA configuration containing all adapters
+ projection_name: Name of the projection (e.g., "q_proj", "k_proj")
+ layer_idx: Layer index for loading the correct LoRA weights
+ """
+ super().__init__()
+
+ self.in_features = original_linear.in_features
+ self.out_features = original_linear.out_features
+ self.projection_name = projection_name
+ self.layer_idx = layer_idx
+ self.lora_config = lora_config
+
+ # Store original linear weights and bias directly without cloning
+ self.register_buffer("weight", original_linear.weight.data)
+ if original_linear.bias is not None:
+ self.register_buffer("bias", original_linear.bias.data)
+ else:
+ self.bias = None
+
+ # Initialize LoRA weights
+ self._init_lora_weights()
+
+ def _should_apply_lora(self) -> bool:
+ """Check if this projection should have LoRA applied."""
+ # Check if any adapter targets this projection
+ return any(self.projection_name in adapter.target_modules for adapter in self.lora_config.adapters)
+
+ def _load_adapter_weights(self, adapter_path: Path):
+ """
+ Load adapter weights from local directory.
+
+ Args:
+ adapter_path: Path to local directory containing adapter weights
+
+ Returns:
+ Dictionary containing adapter weights
+
+ Raises:
+ FileNotFoundError: If no adapter weights are found in the directory
+ """
+ if not adapter_path.is_dir():
+ raise ValueError(f"Adapter path must be a directory, got: {adapter_path}")
+
+ # Try to load weights in order of preference
+ weight_files = [
+ ("adapter_model.safetensors", lambda p: safetensors.torch.load_file(p)),
+ ("adapter_model.bin", lambda p: torch.load(p, map_location="cpu")),
+ ("pytorch_model.bin", lambda p: torch.load(p, map_location="cpu")),
+ ]
+
+ for filename, load_fn in weight_files:
+ weight_path = adapter_path / filename
+ if weight_path.exists():
+ return load_fn(weight_path)
+
+ raise FileNotFoundError(
+ f"No adapter weights found in {adapter_path}. "
+ f"Expected one of: {', '.join(filename for filename, _ in weight_files)}"
+ )
+
+ def _init_lora_weights(self):
+ """Initialize LoRA adapter weights by loading and stacking them."""
+
+ lora_a_weights = []
+ lora_b_weights = []
+
+ for adapter in self.lora_config.adapters:
+ if self.projection_name not in adapter.target_modules:
+ # Create zero weights for adapters that don't target this projection
+ lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+ lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+ continue
+
+ adapter_weights = self._load_adapter_weights(adapter.local_adapter_path)
+
+ # Determine module type from projection name
+ attn_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
+ mlp_projs = {"gate_proj", "up_proj", "down_proj"}
+ if self.projection_name in attn_projs:
+ module_type = "self_attn"
+ elif self.projection_name in mlp_projs:
+ module_type = "mlp"
+ else:
+ module_type = "self_attn"
+
+ layer_key = f"base_model.model.model.layers.{self.layer_idx}.{module_type}.{self.projection_name}"
+ lora_a_key = f"{layer_key}.lora_A.weight"
+ lora_b_key = f"{layer_key}.lora_B.weight"
+
+ if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
+ # Calculate scaling factor and fold it into lora_b
+ scaling = adapter.lora_alpha / adapter.r
+ if adapter.use_rslora:
+ scaling = scaling / math.sqrt(adapter.r)
+ scaling = scaling * adapter.scaling_factor
+
+ lora_a_weights.append(adapter_weights[lora_a_key])
+ # scaling is pre-applied to lora_b_weights
+ lora_b_weights.append(adapter_weights[lora_b_key] * scaling)
+ else:
+ logger.warning(f"No LoRA weights found for {lora_a_key} or {lora_b_key}")
+ lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+ lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+
+ # Stack weights along adapter dimension
+ max_rank = self.lora_config.max_lora_rank
+
+ # Pad smaller ranks to max_rank
+ padded_lora_a = []
+ padded_lora_b = []
+
+ for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+ current_rank = lora_a.shape[0]
+ if current_rank < max_rank:
+ # Pad with zeros
+ padded_a = torch.zeros(max_rank, self.in_features)
+ padded_b = torch.zeros(self.out_features, max_rank)
+ padded_a[:current_rank] = lora_a
+ padded_b[:, :current_rank] = lora_b
+ padded_lora_a.append(padded_a)
+ padded_lora_b.append(padded_b)
+ else:
+ padded_lora_a.append(lora_a)
+ padded_lora_b.append(lora_b)
+
+ lora_a_transposed = [lora_a.transpose(0, 1) for lora_a in padded_lora_a] # [in_features, rank]
+ lora_b_transposed = [lora_b.transpose(0, 1) for lora_b in padded_lora_b] # [rank, out_features]
+
+ self.register_buffer(
+ "lora_a_weights", torch.stack(lora_a_transposed, dim=0).to(self.weight.dtype)
+ ) # [num_adapters, in_features, rank]
+ self.register_buffer(
+ "lora_b_weights", torch.stack(lora_b_transposed, dim=0).to(self.weight.dtype)
+ ) # [num_adapters, rank, out_features]
+
+ def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Forward pass that combines base linear transformation with LoRA.
+
+ Args:
+ x: Input tensor [batch_size, seq_len, in_features]
+ lora_int_id: Adapter ID tensor [batch_size] indicating which adapter to use
+
+ Returns:
+ Output tensor [batch_size, seq_len, out_features]
+ """
+ # Base linear transformation
+ output = torch.nn.functional.linear(x, self.weight, self.bias)
+
+ # Apply LoRA if enabled and adapter ID is provided
+ if self._should_apply_lora() and lora_int_id is not None:
+ # Gather LoRA weights for each batch item
+ # lora_int_id: [batch_size] -> use as indices to select weights
+ selected_lora_a = self.lora_a_weights[lora_int_id] # [batch_size, in_features, rank]
+ selected_lora_b = self.lora_b_weights[lora_int_id] # [batch_size, rank, out_features]
+
+ # Batched matrix multiplication for LoRA computation
+ # x: [batch_size, seq_len, in_features]
+ # selected_lora_a: [batch_size, in_features, rank] (already transposed)
+ # selected_lora_b: [batch_size, rank, out_features] (already transposed)
+
+ # First matmul: x @ lora_a -> [batch_size, seq_len, rank]
+ temp = torch.bmm(x, selected_lora_a)
+
+ # Second matmul: temp @ lora_b -> [batch_size, seq_len, out_features]
+ lora_delta = torch.bmm(temp, selected_lora_b)
+
+ # Add LoRA delta to base output
+ output = output + lora_delta
+
+ return output
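
For context, the new LoRALinear module packs every registered adapter into stacked lora_a_weights / lora_b_weights buffers and selects one adapter per batch row with an index gather followed by two batched matmuls. The standalone sketch below reproduces that arithmetic with toy shapes; all sizes and values are illustrative and not taken from the package.

import torch

# Toy dimensions (illustrative only).
num_adapters, in_features, out_features, max_rank = 3, 16, 32, 8
batch_size, seq_len = 2, 4

# Stacked adapter weights, laid out the way LoRALinear buffers them:
# A: [num_adapters, in_features, rank], B: [num_adapters, rank, out_features] (scaling pre-folded into B).
lora_a = torch.randn(num_adapters, in_features, max_rank)
lora_b = torch.randn(num_adapters, max_rank, out_features)
base_weight = torch.randn(out_features, in_features)

x = torch.randn(batch_size, seq_len, in_features)
lora_int_id = torch.tensor([0, 2])  # one adapter index per batch row

# Base projection plus per-row LoRA delta via two batched matmuls, mirroring LoRALinear.forward.
base_out = torch.nn.functional.linear(x, base_weight)
delta = torch.bmm(torch.bmm(x, lora_a[lora_int_id]), lora_b[lora_int_id])
output = base_out + delta
print(output.shape)  # torch.Size([2, 4, 32])

Because the adapter choice is just a tensor index, a single compiled graph can serve a different adapter per request, which is what the new lora_int_ids input threaded through this release enables.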
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights

@@ -317,12 +317,27 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):

  @classmethod
  def get_pytorch_model(
- cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+ cls,
+ model_id: str,
+ *args,
+ rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+ num_hidden_layers: Optional[int] = None,
+ **kwargs,
  ) -> PreTrainedModel:
  if rbln_config and rbln_config.quantization:
- model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
+ model = cls.get_quantized_model(model_id, *args, rbln_config=rbln_config, **kwargs)
  else:
- model = super().get_pytorch_model(*args, **kwargs)
+ if num_hidden_layers is not None:
+ trust_remote_code = kwargs.get("trust_remote_code", None)
+ config, kwargs = AutoConfig.from_pretrained(
+ model_id, return_unused_kwargs=True, num_hidden_layers=num_hidden_layers, **kwargs
+ )
+ if hasattr(config, "layer_types"):
+ config.layer_types = config.layer_types[:num_hidden_layers]
+ kwargs["config"] = config
+ kwargs["trust_remote_code"] = trust_remote_code
+
+ model = super().get_pytorch_model(model_id, *args, **kwargs)

  return model

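
The num_hidden_layers branch above leans on transformers.AutoConfig.from_pretrained accepting config attributes as keyword overrides and handing back the unused kwargs. A minimal sketch of that call pattern, using a placeholder checkpoint id:

from transformers import AutoConfig

# Placeholder model id; any decoder-only checkpoint with `num_hidden_layers` behaves the same way.
config, unused_kwargs = AutoConfig.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    return_unused_kwargs=True,
    num_hidden_layers=2,  # e.g. compile only two layers for a quick smoke test
)

# Hybrid-attention configs (such as Gemma 3) also carry a per-layer `layer_types`
# list, which is why the diff truncates it to match the reduced depth.
if hasattr(config, "layer_types"):
    config.layer_types = config.layer_types[: config.num_hidden_layers]

print(config.num_hidden_layers)  # 2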
@@ -375,6 +390,9 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  if rbln_config.use_position_ids:
  input_info.append(("position_ids", [batch_size, query_length], "int32"))

+ if rbln_config.use_lora:
+ input_info.append(("lora_int_ids", [batch_size], "int32"))
+
  kvcache_dtype = rbln_config.torch_dtype
  if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
  kvcache_dtype = "float8_e4m3fn"
@@ -667,6 +685,53 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
  return is_prefill

+ def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+ if isinstance(lora_int_ids, int):
+ lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+ elif isinstance(lora_int_ids, list):
+ lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
+
+ self.lora_int_ids = lora_int_ids
+
+ self.prefill_decoder.lora_int_ids = lora_int_ids
+ if self.rbln_config.can_generate:
+ for batch_size in self.rbln_config.decoder_batch_sizes:
+ self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+ """
+ Sets the active adapter(s) for the model using adapter name(s).
+
+ Args:
+ adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+ Can be a single adapter name or a list of adapter names.
+
+ Raises:
+ ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+ """
+ if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+ raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+ # Convert single adapter name to list for uniform processing
+ if isinstance(adapter_name, str):
+ adapter_names = [adapter_name]
+ else:
+ adapter_names = adapter_name
+
+ # Validate that all adapter names exist
+ available_adapters = {
+ adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+ }
+ missing_adapters = [name for name in adapter_names if name not in available_adapters]
+ if missing_adapters:
+ raise ValueError(
+ f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+ )
+
+ # Get the adapter IDs and set them
+ lora_int_ids = [available_adapters[name] for name in adapter_names]
+ self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
+
  def forward(
  self,
  input_ids: Optional[torch.LongTensor] = None,
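
Together, set_lora_int_ids and set_adapter form the runtime adapter-selection API for models compiled with LoRA. A hypothetical usage sketch, assuming `model` is an RBLN-compiled causal LM whose rbln_config registered two adapters named "math" (id 0) and "chat" (id 1) and `inputs` holds tokenizer outputs; the names and ids are illustrative, not taken from this diff:

import torch

# Route every sequence in the batch through the "math" adapter by name.
model.set_adapter("math")
out = model.generate(**inputs, max_new_tokens=32)

# Or pick a different adapter per batch row by integer id; the tensor length
# must match the decoder batch size, as decode_forward validates above.
model.set_lora_int_ids(torch.tensor([0, 1], dtype=torch.int32))
out = model.generate(**inputs, max_new_tokens=32)

Passing lora_int_ids directly to forward() is also accepted, which is how per-request adapters would be driven from a serving loop.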
@@ -677,6 +742,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  padded_cache_lengths: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.Tensor] = None,
  token_type_ids: Optional[torch.Tensor] = None,
+ lora_int_ids: Optional[torch.Tensor] = None,
  return_dict: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
@@ -684,6 +750,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
  # A for-loop ensures synchronization with the HuggingFace generate API.
  # The decoder stage operates as usual, processing inputs in batch mode.
+ if self.rbln_config.use_lora and lora_int_ids is None:
+ if self.lora_int_ids is None:
+ raise ValueError(
+ "lora_int_id is required when using LoRA. "
+ "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+ )
+ lora_int_ids = self.lora_int_ids

  # for only use forward
  if generate_idx is None:
@@ -708,6 +781,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  cache_position=cache_position,
  batch_idx=b_idx,
  token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+ lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
  )
  padded_cache_lengths[b_idx] += output.padded_cache_lengths
  logits.append(output.logits)
@@ -727,6 +801,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
  position_ids=position_ids if self.rbln_config.use_position_ids else None,
+ lora_int_ids=lora_int_ids,
  ).logits

  if not return_dict:
@@ -63,6 +63,7 @@ class Gemma3TextModel(DecoderOnlyModel):
  rotary_emb: torch.nn.Module = None,
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  # retrieve input_ids and inputs_embeds
  if (input_ids is None) ^ (inputs_embeds is not None):
@@ -105,6 +106,7 @@ class Gemma3TextModel(DecoderOnlyModel):
  cos=cos_local if is_sliding else cos_global,
  sin=sin_local if is_sliding else sin_global,
  block_tables=local_block_tables if is_sliding else global_block_tables,
+ lora_int_id=lora_int_id,
  )

  hidden_states = self.get_last_layernorm()(hidden_states)
@@ -127,12 +129,20 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
  block_tables: Optional[torch.Tensor] = None,
+ lora_int_id: Optional[torch.Tensor] = None,
  ):
  residual = hidden_states
  hidden_states = self.get_pre_attention_layernorm()(hidden_states)

  hidden_states = self.self_attn(
- hidden_states, attention_mask, seq_positions, past_key_values, cos, sin, block_tables
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ seq_positions=seq_positions,
+ past_key_values=past_key_values,
+ cos=cos,
+ sin=sin,
+ block_tables=block_tables,
+ lora_int_id=lora_int_id,
  )
  hidden_states = self.get_post_attention_layernorm()(hidden_states)
  hidden_states = residual + hidden_states
@@ -140,7 +150,7 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
  # Fully Connected
  residual = hidden_states
  hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
- hidden_states = self._original_mod.mlp(hidden_states)
+ hidden_states = self.forward_mlp(hidden_states, lora_int_id)
  hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
  hidden_states = residual + hidden_states

@@ -17,15 +17,16 @@ import rebel
  import torch

  from ...modeling_outputs import RBLNDecoderOnlyOutput, RBLNGemma3ForCausalLMOutput
+ from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
  from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel


  class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
  super().__init__(*args, **kwargs)
- self.image_prefill = image_prefill # FIXME(taehoon)
- self.prefill = self.runtime if self.phase == "prefill" else None # FIXME
- self.decode = self.runtime if self.phase == "decode" else None
+ self.image_prefill = RBLNPytorchRuntime(image_prefill) # FIXME(taehoon)
+ self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None # FIXME
+ self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None

  def _prepare_prefill_inputs(self, *args, **kwargs):
  (
@@ -73,12 +74,24 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  position_embed: Optional[torch.Tensor] = None,
  token_type_ids: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_ids: Optional[torch.Tensor] = None,
  ) -> torch.FloatTensor:
  """
  Performs chunked prefill for efficient KV-cache updates and memory optimization.
  Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
  and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
  """
+ if self.rbln_config.use_lora and lora_int_ids is None:
+ if self.lora_int_ids is None:
+ raise ValueError(
+ "lora_int_id is required when using LoRA. "
+ "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+ )
+ if batch_idx is not None:
+ lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+ else:
+ lora_int_ids = self.lora_int_ids.clone()
+
  (
  inputs,
  cache_position,
@@ -141,6 +154,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  query_position,
  chunked_attention_mask,
  position_ids_chunk,
+ lora_int_ids if self.rbln_config.use_lora else None,
  )
  else:
  logits = self.prefill(
@@ -151,6 +165,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  query_position,
  chunked_attention_mask,
  position_ids_chunk,
+ lora_int_ids if self.rbln_config.use_lora else None,
  )

  padded_cache_lengths += current_padded_cache_lengths
@@ -173,7 +188,20 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
  position_embed: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
+ lora_int_ids: Optional[torch.Tensor] = None,
  ) -> torch.FloatTensor:
+ if self.rbln_config.use_lora and lora_int_ids is None:
+ if self.lora_int_ids is None:
+ raise ValueError(
+ "lora_int_id is required when using LoRA. "
+ "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+ )
+
+ lora_int_ids = self.lora_int_ids
+
+ if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+ raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
  batch_size = inputs.shape[0]
  if batch_size != self.batch_size:
  raise RuntimeError(
@@ -28,6 +28,7 @@ from ....modeling import RBLNModel
  from ...modeling_outputs import RBLNDecoderOnlyOutput
  from ...utils.rbln_runtime_wrapper import LoopProcessor
  from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
  from ..decoderonly.modeling_decoderonly import (
  RBLNDecoderOnlyModelForCausalLM,
  )
@@ -77,7 +78,7 @@ class LoopProjector(LoopProcessor):
  return output[0]


- class RBLNGemma3ForConditionalGeneration(RBLNModel):
+ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
  auto_model_class = AutoModelForImageTextToText
  _rbln_submodules = [
  {"name": "vision_tower"},
@@ -408,6 +409,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
  def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
  sliding_window = getattr(model_config, "sliding_window", None)
  sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+ if sliding_window_pattern is None:
+ if hasattr(model_config, "layer_types"):
+ first_full_attention_index = model_config.layer_types.index("full_attention")
+ sliding_window_pattern = first_full_attention_index + 1
+ else:
+ raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
  if sliding_window_pattern <= model_config.num_hidden_layers:
  rbln_config.cache_impl = "hybrid"
  rbln_config.sliding_window = sliding_window
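
The fallback above recovers sliding_window_pattern from layer_types when the config no longer exposes the explicit attribute: the index of the first "full_attention" layer plus one gives the local/global period. A toy check with an illustrative Gemma 3-style layout (the list below is made up for demonstration):

# Five sliding-attention layers followed by one full-attention layer, repeating.
layer_types = (["sliding_attention"] * 5 + ["full_attention"]) * 2

sliding_window_pattern = layer_types.index("full_attention") + 1
print(sliding_window_pattern)  # 6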
@@ -75,7 +75,10 @@ class GPT2Attention(DecoderOnlyAttention):
  self.o_proj = self._original_mod.c_proj
  self.split_size = self._original_mod.split_size

- def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ if lora_int_id is not None:
+ raise NotImplementedError("LoRA is not supported for GPT2Attention")
+
  query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
  return query_states, key_states, value_states

@@ -35,6 +35,7 @@ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
  from ....utils.runtime_utils import RBLNPytorchRuntime
  from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin


  if TYPE_CHECKING:
@@ -120,9 +121,6 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
  encoder_outputs = self.encoder(
  inputs_embeds=hidden_states,
  attention_mask=patch_attention_mask,
- output_attentions=None,
- output_hidden_states=None,
- return_dict=False,
  )
  last_hidden_state = encoder_outputs[0]
  last_hidden_state = self.post_layernorm(last_hidden_state)
@@ -185,7 +183,7 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
  return BaseModelOutput(last_hidden_state=last_hidden_state)


- class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+ class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
  """
  RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
  optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -27,6 +27,7 @@ from ....modeling import RBLNModel
  from ....utils.logging import get_logger
  from ...modeling_outputs import RBLNDecoderOnlyOutput
  from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin


  logger = get_logger(__name__)
@@ -103,7 +104,7 @@ class LoopProjector(LoopProcessor):
  return output[0]


- class RBLNLlavaForConditionalGeneration(RBLNModel):
+ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
  """
  RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
  optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -32,6 +32,7 @@ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
  from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
  from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput


@@ -87,7 +88,7 @@ class LoopProjector(LoopProcessor):
  return output[0]


- class RBLNLlavaNextForConditionalGeneration(RBLNModel):
+ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
  """
  RBLNLlavaNextForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
  optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.