optimum-rbln 0.9.2a3__py3-none-any.whl → 0.9.2a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (34)
  1. optimum/rbln/__init__.py +4 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +3 -0
  4. optimum/rbln/modeling.py +71 -1
  5. optimum/rbln/transformers/__init__.py +4 -0
  6. optimum/rbln/transformers/modeling_generic.py +23 -1
  7. optimum/rbln/transformers/models/__init__.py +4 -0
  8. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +65 -1
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  11. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  13. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
  14. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  15. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +79 -4
  16. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  17. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  18. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +9 -1
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  20. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -4
  21. optimum/rbln/transformers/models/llava/modeling_llava.py +2 -1
  22. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -1
  23. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  24. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  25. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +0 -9
  26. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
  27. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
  28. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  29. optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
  30. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  31. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/METADATA +5 -5
  32. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/RECORD +34 -32
  33. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/WHEEL +0 -0
  34. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/configuration_lora.py (new file)
@@ -0,0 +1,411 @@
+import json
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from huggingface_hub import snapshot_download
+
+from ....configuration_utils import RBLNSerializableConfigProtocol
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class RBLNLoRAAdapterConfig(RBLNSerializableConfigProtocol):
+    """
+    Configuration class for individual LoRA adapter settings.
+
+    This class represents a single LoRA adapter that will be compiled into the RBLN model.
+    Since RBLN NPU requires all adapters to be determined at compile time, each adapter
+    must be fully specified including its weights.
+
+    Examples:
+        ```python
+        from transformers import AutoTokenizer
+
+        from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig, RBLNLoRAAdapterConfig, RBLNLoRAConfig
+
+
+        model_id = "meta-llama/Llama-3.1-8B-Instruct"
+        lora_ids = [
+            "nvidia/llama-3.1-nemoguard-8b-topic-control",
+            "reissbaker/llama-3.1-8b-abliterated-lora",
+        ]
+        prompt = "What are the safety considerations for AI systems?"
+        tp_size = 4
+
+        # adapter id should be higher than 0
+        # 0 is reserved for base model
+        lora_config = RBLNLoRAConfig(
+            adapters=[
+                RBLNLoRAAdapterConfig(1, "nemoguard", lora_ids[0]),
+                RBLNLoRAAdapterConfig(2, "abliterated", lora_ids[1]),
+            ],
+        )
+
+        model = RBLNLlamaForCausalLM.from_pretrained(
+            model_id,
+            rbln_config=RBLNLlamaForCausalLMConfig(lora_config=lora_config, tensor_parallel_size=tp_size, max_seq_len=8192),
+            torch_dtype="auto",
+        )
+
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+
+        prompt_template = tokenizer.apply_chat_template(
+            [
+                {"role": "system", "content": "You are a helpful assistant. Always be concise"},
+                {"role": "user", "content": prompt},
+            ],
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        inputs = tokenizer([prompt_template], return_tensors="pt")
+        input_len = inputs["input_ids"].shape[-1]
+
+        for adapter_name in lora_config.adapter_names:
+            model.set_adapter(adapter_name)
+            decoder_outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+            generated_text = tokenizer.decode(decoder_outputs[0][input_len:], skip_special_tokens=True)
+            print(generated_text + "\n")
+        ```
+
+    """
+
+    def __init__(
+        self,
+        lora_int_id: int,
+        lora_name: str,
+        lora_path: Union[str, Path],
+        r: Optional[int] = None,
+        lora_alpha: Optional[float] = None,
+        target_modules: Optional[List[str]] = None,
+        bias: Optional[str] = None,
+        use_rslora: Optional[bool] = None,
+        scaling_factor: Optional[float] = None,
+    ):
+        """
+        Args:
+            lora_int_id (int): Unique identifier for this LoRA adapter (e.g., 0, 1, 2).
+                This ID will be used during runtime to select which adapter to use.
+            lora_name (str): Human-readable name for this adapter (e.g., "math_tuned", "code_tuned").
+            lora_path (Union[str, Path]): Path to the LoRA adapter weights directory or file.
+                Must be accessible at compile time to load the weights.
+            r (Optional[int]): The rank of the LoRA approximation for this adapter. If None,
+                will be loaded from adapter config file.
+            lora_alpha (Optional[float]): The LoRA scaling parameter for this adapter. If None,
+                will be loaded from adapter config file.
+            target_modules (Optional[List[str]]): List of module names to apply LoRA to.
+                If None, will be loaded from adapter config file or inherit from parent RBLNLoRAConfig.
+            bias (Optional[str]): Bias handling strategy. Options: "none", "all", "lora_only".
+                If None, will be loaded from adapter config file.
+            use_rslora (Optional[bool]): Whether to use Rank-Stabilized LoRA. If None,
+                will be loaded from adapter config file.
+            scaling_factor (Optional[float]): Additional scaling factor for this adapter. Defaults to 1.0.
+            **kwargs: Additional adapter-specific arguments.
+
+        Raises:
+            ValueError: If lora_int_id is None.
+            ValueError: If lora_path doesn't exist.
+            ValueError: If r is not a positive integer.
+            ValueError: If lora_alpha is not positive.
+        """
+        if lora_int_id is None:
+            raise ValueError("lora_int_id cannot be None")
+
+        if not isinstance(lora_int_id, int):
+            raise ValueError(f"lora_int_id must be an integer, got {type(lora_int_id)}")
+
+        self.lora_int_id = lora_int_id
+        self.lora_name = lora_name
+
+        # Keep original lora_path as provided by user (for serialization)
+        self.lora_path = Path(lora_path)
+
+        # Resolve to local directory path (for actual weight loading)
+        self.local_adapter_path = self._resolve_adapter_path(self.lora_path)
+
+        # Load adapter config and use as defaults
+        adapter_config = self._load_adapter_config()
+
+        # Set values from adapter config if not explicitly provided
+        self.r = r if r is not None else adapter_config.get("r", 8)
+        self.lora_alpha = lora_alpha if lora_alpha is not None else adapter_config.get("lora_alpha", 8.0)
+        self.target_modules = (
+            target_modules if target_modules is not None else adapter_config.get("target_modules", None)
+        )
+        self.bias = bias if bias is not None else adapter_config.get("bias", "none")
+        if self.bias not in ["none"]:
+            raise NotImplementedError("bias != 'none' is not supported yet")
+
+        self.use_rslora = use_rslora if use_rslora is not None else adapter_config.get("use_rslora", False)
+        self.scaling_factor = scaling_factor if scaling_factor is not None else 1.0
+
+        # Validate the final values
+        if not isinstance(self.r, int) or self.r <= 0:
+            raise ValueError(f"r must be a positive integer, got {self.r}")
+
+        if self.lora_alpha <= 0:
+            raise ValueError(f"lora_alpha must be positive, got {self.lora_alpha}")
+
+        if self.bias not in ["none", "all", "lora_only"]:
+            raise ValueError(f"bias must be one of ['none', 'all', 'lora_only'], got {self.bias}")
+
+    def _resolve_adapter_path(self, path: Path) -> Path:
+        """
+        Resolve the adapter path, downloading from HuggingFace Hub if necessary.
+
+        Args:
+            path: Local path or HuggingFace Hub model ID
+
+        Returns:
+            Path object pointing to local adapter directory
+
+        Raises:
+            ValueError: If the adapter cannot be found locally or downloaded
+        """
+        # If it's a local path and exists, return it
+        if path.exists():
+            return path
+
+        # If it's an absolute path that doesn't exist, raise error
+        if path.is_absolute():
+            raise ValueError(f"LoRA adapter path does not exist: {path.as_posix()}")
+
+        # Try to interpret as HuggingFace Hub model ID and download
+        try:
+            local_dir = snapshot_download(str(path), allow_patterns=["*.safetensors", "*.bin", "*.json"])
+            return Path(local_dir)
+        except Exception as e:
+            raise ValueError(
+                f"Failed to download LoRA adapter '{path.as_posix()}' from HuggingFace Hub. "
+                f"Please check if the model ID is correct or provide a valid local path. "
+                f"Error: {e}"
+            )
+
+    def _load_adapter_config(self) -> Dict[str, Any]:
+        """
+        Load adapter configuration from adapter_config.json file.
+
+        Returns:
+            Dictionary containing adapter configuration
+
+        Raises:
+            ValueError: If adapter_config.json is not found or cannot be parsed
+        """
+        config_path = self.local_adapter_path / "adapter_config.json"
+
+        if not config_path.exists():
+            logger.warning(f"No adapter_config.json found at {config_path}, using default values")
+            return {}
+
+        try:
+            with open(config_path, "r", encoding="utf-8") as f:
+                adapter_config = json.load(f)
+            logger.info(f"Loaded adapter config from {config_path}")
+            return adapter_config
+        except Exception as e:
+            logger.warning(f"Failed to load adapter config from {config_path}: {e}, using default values")
+            return {}
+
+    def _prepare_for_serialization(self) -> Dict[str, Any]:
+        config_dict = {
+            "lora_int_id": self.lora_int_id,
+            "lora_name": self.lora_name,
+            "lora_path": str(self.lora_path),
+            "r": self.r,
+            "lora_alpha": self.lora_alpha,
+            "target_modules": self.target_modules,
+            "bias": self.bias,
+            "use_rslora": self.use_rslora,
+            "scaling_factor": self.scaling_factor,
+        }
+        return config_dict
+
+
+class RBLNLoRABaseAdapterConfig(RBLNLoRAAdapterConfig):
+    """
+    Special adapter config for the reserved base model adapter (lora_int_id = 0).
+    This adapter carries zero-effective LoRA weights by targeting no modules,
+    thereby producing no LoRA delta and yielding pure base-model behavior.
+    """
+
+    def __init__(
+        self,
+        lora_int_id: int = 0,
+        lora_name: str = "base",
+        lora_path: Union[str, Path] = "__reserved_base__",
+        r: Optional[int] = 1,
+        lora_alpha: Optional[float] = 1.0,
+        target_modules: Optional[List[str]] = None,
+        bias: Optional[str] = "none",
+        use_rslora: Optional[bool] = False,
+        scaling_factor: Optional[float] = 1.0,
+    ):
+        if lora_int_id != 0:
+            raise ValueError("RBLNLoRABaseAdapterConfig must have lora_int_id=0")
+
+        self.lora_int_id = 0
+        self.lora_name = lora_name
+        # Keep original lora_path for serialization purposes but do not resolve it.
+        self.lora_path = Path(str(lora_path))
+        self.local_adapter_path = None
+
+        # Set minimal defaults; target_modules empty disables LoRA on all projections
+        self.r = 1 if r is None else r
+        self.lora_alpha = 1.0 if lora_alpha is None else lora_alpha
+        self.target_modules = []
+        self.bias = "none"
+        self.use_rslora = False
+        self.scaling_factor = 1.0
+
+        # Validate minimal settings
+        if not isinstance(self.r, int) or self.r <= 0:
+            raise ValueError(f"r must be a positive integer, got {self.r}")
+        if self.lora_alpha <= 0:
+            raise ValueError(f"lora_alpha must be positive, got {self.lora_alpha}")
+
+
+class RBLNLoRAConfig(RBLNSerializableConfigProtocol):
+    """
+    Configuration class for multi-LoRA support in RBLN decoder-only models.
+
+    This class manages all LoRA adapters that will be compiled into the RBLN model.
+    Since RBLN NPU requires all adapters to be determined at compile time, this
+    configuration must specify all adapters upfront with their weights.
+
+    Key constraints for RBLN multi-LoRA:
+    1. All LoRA adapters must be specified at compile time
+    2. Adapter weights must be available during compilation
+    3. The number of adapters is fixed after compilation
+    4. Runtime can only switch between pre-compiled adapters
+    """
+
+    def __init__(
+        self, adapters: List[Union[Dict[str, Any], RBLNLoRAAdapterConfig]], max_lora_rank: Optional[int] = None
+    ):
+        """
+        Args:
+            adapters (List[Union[Dict[str, Any], RBLNLoRAAdapterConfig]]): List of LoRA adapters
+                to be compiled into the model. Each adapter must be fully specified with weights
+                accessible at compile time.
+            max_lora_rank (Optional[int]): Maximum rank across all adapters. If None, automatically
+                determined from the provided adapters. Used for memory allocation optimization.
+
+        Raises:
+            ValueError: If adapters list is empty.
+            ValueError: If adapter IDs are not unique.
+            ValueError: If any adapter path doesn't exist.
+        """
+        if not adapters:
+            raise ValueError("adapters list cannot be empty")
+
+        # Convert dict adapters to RBLNLoRAAdapterConfig objects
+        self.adapters: List[RBLNLoRAAdapterConfig] = []
+        for adapter in adapters:
+            if isinstance(adapter, dict):
+                self.adapters.append(RBLNLoRAAdapterConfig(**adapter))
+            elif isinstance(adapter, RBLNLoRAAdapterConfig):
+                self.adapters.append(adapter)
+            else:
+                raise ValueError(f"Invalid adapter type: {type(adapter)}")
+
+        # Disallow user-provided adapter with id 0: it's reserved for base model
+        if any(ad.lora_int_id == 0 for ad in self.adapters):
+            raise ValueError(
+                "lora_int_id=0 is reserved for base model and cannot be provided. Please renumber your adapters to start from 1."
+            )
+
+        # Inject a reserved zero-weight adapter for base model at id=0
+        base_adapter = RBLNLoRABaseAdapterConfig()
+        self.adapters.insert(0, base_adapter)
+
+        # Sort adapters by ID to make IDs align with indices
+        self.adapters.sort(key=lambda a: a.lora_int_id)
+
+        # Validate unique and contiguous adapter IDs starting from 0
+        adapter_ids = [adapter.lora_int_id for adapter in self.adapters]
+        if len(adapter_ids) != len(set(adapter_ids)):
+            raise ValueError("All adapter IDs must be unique")
+        expected_ids = list(range(len(self.adapters)))
+        if adapter_ids != expected_ids:
+            raise ValueError(
+                f"Adapter IDs must be contiguous and start from 0. Found {adapter_ids}, expected {expected_ids}."
+            )
+
+        # Calculate max_lora_rank if not provided
+        if max_lora_rank is None:
+            self.max_lora_rank = max(adapter.r for adapter in self.adapters)
+        else:
+            self.max_lora_rank = max_lora_rank
+            # Validate that max_lora_rank is sufficient
+            actual_max_rank = max(adapter.r for adapter in self.adapters)
+            if self.max_lora_rank < actual_max_rank:
+                raise ValueError(
+                    f"max_lora_rank ({self.max_lora_rank}) must be >= actual max rank ({actual_max_rank})"
+                )
+
+    @property
+    def num_adapters(self) -> int:
+        return len(self.adapters)
+
+    @property
+    def adapter_ids(self) -> List[int]:
+        return [adapter.lora_int_id for adapter in self.adapters]
+
+    @property
+    def adapter_names(self) -> List[str]:
+        return [adapter.lora_name for adapter in self.adapters]
+
+    def get_adapter_by_id(self, lora_int_id: int) -> Optional[RBLNLoRAAdapterConfig]:
+        for adapter in self.adapters:
+            if adapter.lora_int_id == lora_int_id:
+                return adapter
+        return None
+
+    def get_adapter_by_name(self, lora_name: str) -> Optional[RBLNLoRAAdapterConfig]:
+        for adapter in self.adapters:
+            if adapter.lora_name == lora_name:
+                return adapter
+        return None
+
+    def validate_adapter_weights(self) -> Dict[int, bool]:
+        validation_results = {}
+        for adapter in self.adapters:
+            try:
+                # The reserved base adapter (id=0) always validates to True
+                if adapter.lora_int_id == 0:
+                    validation_results[adapter.lora_int_id] = True
+                    continue
+                # Check if adapter path exists and contains expected files
+                adapter_path = adapter.local_adapter_path
+                if adapter_path is not None and adapter_path.is_file():
+                    # Single file adapter (e.g., safetensors)
+                    validation_results[adapter.lora_int_id] = adapter_path.exists()
+                else:
+                    # Directory adapter - check for common LoRA files
+                    expected_files = ["adapter_model.safetensors", "adapter_config.json"]
+                    alternative_files = ["pytorch_model.bin", "adapter_model.bin"]
+
+                    has_weights = adapter_path is not None and any(
+                        (adapter_path / f).exists() for f in expected_files + alternative_files
+                    )
+                    has_config = adapter_path is not None and (adapter_path / "adapter_config.json").exists()
+
+                    validation_results[adapter.lora_int_id] = has_weights and has_config
+            except Exception as e:
+                logger.warning(f"Failed to validate adapter {adapter.lora_int_id}: {e}")
+                validation_results[adapter.lora_int_id] = False
+
+        return validation_results
+
+    def _prepare_for_serialization(self) -> Dict[str, Any]:
+        # Do not serialize the reserved base adapter (id=0)
+        serializable_adapters = [adapter for adapter in self.adapters if adapter.lora_int_id != 0]
+        serializable_map = {
+            "adapters": [adapter._prepare_for_serialization() for adapter in serializable_adapters],
+            "max_lora_rank": self.max_lora_rank,
+        }
+        return serializable_map
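
The normalization that `RBLNLoRAConfig.__init__` performs above (a reserved zero-delta base adapter injected at id 0, contiguous adapter IDs, and a derived `max_lora_rank`) can be exercised without compiling a model. A minimal sketch, assuming the adapter paths below are placeholders for real PEFT-style adapter directories or Hub IDs that contain an `adapter_config.json` and weights:

```python
from optimum.rbln import RBLNLoRAAdapterConfig, RBLNLoRAConfig

# Adapters may be given as dicts or RBLNLoRAAdapterConfig objects; user IDs must start at 1,
# because id 0 is reserved and injected automatically as the zero-delta base adapter.
lora_config = RBLNLoRAConfig(
    adapters=[
        {"lora_int_id": 1, "lora_name": "math_tuned", "lora_path": "./adapters/math"},  # placeholder path
        RBLNLoRAAdapterConfig(2, "code_tuned", "./adapters/code", r=16),                 # placeholder path
    ],
)

print(lora_config.adapter_ids)    # [0, 1, 2] -- base adapter injected at id 0
print(lora_config.adapter_names)  # ['base', 'math_tuned', 'code_tuned']
print(lora_config.max_lora_rank)  # 16 here, unless adapter 1's adapter_config.json declares a larger r
print(lora_config.validate_adapter_weights())  # {0: True, 1: ..., 2: ...}
```

Because `_resolve_adapter_path` falls back to `snapshot_download`, the same construction also works with adapters hosted on the Hugging Face Hub, as in the docstring example above.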
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -22,6 +22,8 @@ from transformers import PretrainedConfig, PreTrainedModel
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig
+from .lora_architecture import LoRALinear


 if TYPE_CHECKING:
@@ -52,12 +54,7 @@ class DecoderOnlyWrapper(nn.Module):

     _use_learned_pos_emb = False

-    def __init__(
-        self,
-        model: PreTrainedModel,
-        rbln_config: "RBLNDecoderOnlyModelConfig",
-        use_rotary_emb: bool,
-    ):
+    def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
         super().__init__()
         self.quantization = rbln_config.quantization
         self.config = model.config
@@ -114,7 +111,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_self_attn = self.get_rbln_attn_class()(
                 self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
             )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
             new_layers.append(new_layer)

         new_model = self.get_rbln_model_class()(
@@ -154,6 +151,7 @@ class DecoderOnlyWrapper(nn.Module):
         )
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
         past_key_values = args

         if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -185,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         )
@@ -199,6 +198,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         ) = self.prepare_forward_args(*args)
@@ -214,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )

         return logit
@@ -270,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         rotary_emb: nn.Module = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # outputs
         hidden_states = self.model(
@@ -283,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )

         if "prefill" in self.phase:
@@ -394,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -466,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
                 cos=cos,
                 sin=sin,
                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
             )

         hidden_states = self.get_last_layernorm()(hidden_states)
@@ -497,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
         phase: Current operation phase ("prefill" or "decode")
     """

-    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
         self._original_mod = layer
         self.self_attn = self_attn
         self._phase = "prefill"
+        self.lora_config = lora_config
+
+        # Replace target Linear modules in MLP with LoRALinear if configured
+        if self.lora_config:
+            mlp = self.get_mlp()
+            for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+                if hasattr(mlp, proj_name):
+                    original_linear = getattr(mlp, proj_name)
+                    if isinstance(original_linear, nn.Linear):
+                        lora_linear = LoRALinear(
+                            original_linear=original_linear,
+                            lora_config=self.lora_config,
+                            projection_name=proj_name,
+                            layer_idx=self.self_attn.layer_idx,
+                        )
+                        setattr(mlp, proj_name, lora_linear)

     @property
     def phase(self):
@@ -518,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.post_attention_layernorm

+    def get_mlp(self) -> nn.Module:
+        return self._original_mod.mlp
+
+    def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        mlp = self.get_mlp()
+        if self.lora_config and lora_int_id is not None:
+            gate = mlp.gate_proj(hidden_states, lora_int_id)
+            up = mlp.up_proj(hidden_states, lora_int_id)
+            act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+            if act_fn is None:
+                gate = torch.nn.functional.silu(gate)
+            else:
+                gate = act_fn(gate)
+            fused = gate * up
+            hidden_states = mlp.down_proj(fused, lora_int_id)
+        else:
+            hidden_states = mlp(hidden_states)
+        return hidden_states
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -527,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -539,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
             cos=cos,
             sin=sin,
             block_tables=block_tables,
+            lora_int_id=lora_int_id,
         )
         hidden_states = residual + hidden_states

         # Fully Connected
         residual = hidden_states
         hidden_states = self.get_post_attention_layernorm()(hidden_states)
-        hidden_states = self._original_mod.mlp(hidden_states)
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
         hidden_states = residual + hidden_states

         return hidden_states
@@ -595,10 +637,23 @@ class DecoderOnlyAttention(nn.Module):
         self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
         self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
         self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+        self.lora_config = rbln_config.lora_config

         setattr(self, self.get_attention_name(), self.create_attention_op())
         self.__post_init__()

+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+        for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            original_linear = getattr(self._original_mod, proj_name)
+            lora_linear = LoRALinear(
+                original_linear=original_linear,
+                lora_config=self.lora_config,
+                projection_name=proj_name,
+                layer_idx=self.layer_idx,
+            )
+            setattr(self, proj_name, lora_linear)
+
     def get_attention_name(self):
         if self.is_sliding:
             return "sliding_window_attention"
@@ -651,23 +706,40 @@ class DecoderOnlyAttention(nn.Module):
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

     def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
-
-    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Initialize LoRA weights if configured, which will replace linear layers
+        if self.lora_config:
+            self._init_lora_weights()
+        else:
+            # Use original linear layers if no LoRA
+            self.q_proj = self._original_mod.q_proj
+            self.k_proj = self._original_mod.k_proj
+            self.v_proj = self._original_mod.v_proj
+            self.o_proj = self._original_mod.o_proj
+
+    def projection(
+        self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Projects input hidden states into query, key, and value representations.

         Args:
             hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+            lora_int_id: Adapter ID tensor for LoRA selection [batch_size]

         Returns:
             Tuple of (query_states, key_states, value_states)
         """
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            query_states = self.q_proj(hidden_states, lora_int_id)
+            key_states = self.k_proj(hidden_states, lora_int_id)
+            value_states = self.v_proj(hidden_states, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
         return query_states, key_states, value_states

     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
@@ -695,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()

-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)

         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -732,7 +805,14 @@ class DecoderOnlyAttention(nn.Module):
             v_scale=v_scale,
         )

-        attn_outputs = self.o_proj(attn_output)
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            attn_outputs = self.o_proj(attn_output, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            attn_outputs = self.o_proj(attn_output)
+
         return attn_outputs

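
The `LoRALinear` module that these hunks import is defined in the new `lora_architecture.py` (+204 lines, not expanded in this view). Conceptually, each replaced projection keeps the frozen base weight and stacks the per-adapter low-rank factors so that the `lora_int_id` tensor threaded through the wrapper, model, layer, and attention forwards can select a delta per batch row. A rough, illustrative sketch of that call pattern, not the actual RBLN implementation:

```python
import torch
import torch.nn as nn


class MultiAdapterLinear(nn.Module):
    """Illustrative only: a linear layer with stacked, precompiled LoRA deltas,
    selected per batch element by an integer adapter id (id 0 = all-zero = base model)."""

    def __init__(self, base: nn.Linear, num_adapters: int, max_rank: int):
        super().__init__()
        self.base = base
        # One [rank, in] / [out, rank] pair per adapter; adapter 0 stays all-zero.
        self.lora_A = nn.Parameter(torch.zeros(num_adapters, max_rank, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(num_adapters, base.out_features, max_rank))

    def forward(self, x: torch.Tensor, lora_int_id: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, in_features], lora_int_id: [batch] integer tensor
        A = self.lora_A[lora_int_id]                    # [batch, rank, in]
        B = self.lora_B[lora_int_id]                    # [batch, out, rank]
        delta = torch.einsum("bsi,bri->bsr", x, A)      # down-projection
        delta = torch.einsum("bsr,bor->bso", delta, B)  # up-projection
        return self.base(x) + delta                     # per-adapter scaling (lora_alpha / r) omitted


layer = MultiAdapterLinear(nn.Linear(64, 64), num_adapters=3, max_rank=16)
out = layer(torch.randn(2, 8, 64), torch.tensor([0, 2]))  # row 0 runs with base weights only
print(out.shape)  # torch.Size([2, 8, 64])
```

This mirrors the calling convention visible in the diff (`self.q_proj(hidden_states, lora_int_id)`, `mlp.down_proj(fused, lora_int_id)`), where adapter selection happens inside the projection rather than by swapping module weights at runtime.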