optimum-rbln 0.9.2a3__py3-none-any.whl → 0.9.2a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +4 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +3 -0
- optimum/rbln/modeling.py +71 -1
- optimum/rbln/transformers/__init__.py +4 -0
- optimum/rbln/transformers/modeling_generic.py +23 -1
- optimum/rbln/transformers/models/__init__.py +4 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +65 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +79 -4
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +9 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +2 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +0 -9
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/METADATA +5 -5
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/RECORD +34 -32
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/licenses/LICENSE +0 -0
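The headline change in this release is compile-time multi-LoRA support for decoder-only models: the new configuration_lora.py defines the adapter configuration classes, lora_architecture.py provides the LoRA-aware projection layers, and decoderonly_architecture.py threads a lora_int_id through the decoder stack. The snippet below condenses the usage example from the new configuration_lora.py docstring (shown in full in the diff that follows); the model and adapter repositories are the ones used in that example and are illustrative only.

```python
# Condensed from the usage example in the new configuration_lora.py docstring;
# the base model and adapter repos are illustrative, not requirements.
from optimum.rbln import (
    RBLNLlamaForCausalLM,
    RBLNLlamaForCausalLMConfig,
    RBLNLoRAAdapterConfig,
    RBLNLoRAConfig,
)

# Every adapter must be declared at compile time; id 0 is reserved for the base model.
lora_config = RBLNLoRAConfig(
    adapters=[
        RBLNLoRAAdapterConfig(1, "nemoguard", "nvidia/llama-3.1-nemoguard-8b-topic-control"),
        RBLNLoRAAdapterConfig(2, "abliterated", "reissbaker/llama-3.1-8b-abliterated-lora"),
    ],
)

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    rbln_config=RBLNLlamaForCausalLMConfig(
        lora_config=lora_config, tensor_parallel_size=4, max_seq_len=8192
    ),
    torch_dtype="auto",
)

# At runtime you can only switch between the adapters that were compiled in above.
model.set_adapter("nemoguard")
```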
optimum/rbln/transformers/models/decoderonly/configuration_lora.py (new file, +411 lines)

@@ -0,0 +1,411 @@
+import json
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from huggingface_hub import snapshot_download
+
+from ....configuration_utils import RBLNSerializableConfigProtocol
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class RBLNLoRAAdapterConfig(RBLNSerializableConfigProtocol):
+    """
+    Configuration class for individual LoRA adapter settings.
+
+    This class represents a single LoRA adapter that will be compiled into the RBLN model.
+    Since RBLN NPU requires all adapters to be determined at compile time, each adapter
+    must be fully specified including its weights.
+
+    Examples:
+        ```python
+        from transformers import AutoTokenizer
+
+        from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig, RBLNLoRAAdapterConfig, RBLNLoRAConfig
+
+
+        model_id = "meta-llama/Llama-3.1-8B-Instruct"
+        lora_ids = [
+            "nvidia/llama-3.1-nemoguard-8b-topic-control",
+            "reissbaker/llama-3.1-8b-abliterated-lora",
+        ]
+        prompt = "What are the safety considerations for AI systems?"
+        tp_size = 4
+
+        # adapter id should be higher than 0
+        # 0 is reserved for base model
+        lora_config = RBLNLoRAConfig(
+            adapters=[
+                RBLNLoRAAdapterConfig(1, "nemoguard", lora_ids[0]),
+                RBLNLoRAAdapterConfig(2, "abliterated", lora_ids[1]),
+            ],
+        )
+
+        model = RBLNLlamaForCausalLM.from_pretrained(
+            model_id,
+            rbln_config=RBLNLlamaForCausalLMConfig(lora_config=lora_config, tensor_parallel_size=tp_size, max_seq_len=8192),
+            torch_dtype="auto",
+        )
+
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+
+
+        prompt_template = tokenizer.apply_chat_template(
+            [
+                {"role": "system", "content": "You are a helpful assistant. Always be concise"},
+                {"role": "user", "content": prompt},
+            ],
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        inputs = tokenizer([prompt_template], return_tensors="pt")
+        input_len = inputs["input_ids"].shape[-1]
+
+        for adapter_name in lora_config.adapter_names:
+            model.set_adapter(adapter_name)
+            decoder_outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+            generated_text = tokenizer.decode(decoder_outputs[0][input_len:], skip_special_tokens=True)
+            print(generated_text + "\n")
+        ```
+
+    """
+
+    def __init__(
+        self,
+        lora_int_id: int,
+        lora_name: str,
+        lora_path: Union[str, Path],
+        r: Optional[int] = None,
+        lora_alpha: Optional[float] = None,
+        target_modules: Optional[List[str]] = None,
+        bias: Optional[str] = None,
+        use_rslora: Optional[bool] = None,
+        scaling_factor: Optional[float] = None,
+    ):
+        """
+        Args:
+            lora_int_id (int): Unique identifier for this LoRA adapter (e.g., 0, 1, 2).
+                This ID will be used during runtime to select which adapter to use.
+            lora_name (str): Human-readable name for this adapter (e.g., "math_tuned", "code_tuned").
+            lora_path (Union[str, Path]): Path to the LoRA adapter weights directory or file.
+                Must be accessible at compile time to load the weights.
+            r (Optional[int]): The rank of the LoRA approximation for this adapter. If None,
+                will be loaded from adapter config file.
+            lora_alpha (Optional[float]): The LoRA scaling parameter for this adapter. If None,
+                will be loaded from adapter config file.
+            target_modules (Optional[List[str]]): List of module names to apply LoRA to.
+                If None, will be loaded from adapter config file or inherit from parent RBLNLoRAConfig.
+            bias (Optional[str]): Bias handling strategy. Options: "none", "all", "lora_only".
+                If None, will be loaded from adapter config file.
+            use_rslora (Optional[bool]): Whether to use Rank-Stabilized LoRA. If None,
+                will be loaded from adapter config file.
+            scaling_factor (Optional[float]): Additional scaling factor for this adapter. Defaults to 1.0.
+            **kwargs: Additional adapter-specific arguments.
+
+        Raises:
+            ValueError: If lora_int_id is None.
+            ValueError: If lora_path doesn't exist.
+            ValueError: If r is not a positive integer.
+            ValueError: If lora_alpha is not positive.
+        """
+        if lora_int_id is None:
+            raise ValueError("lora_int_id cannot be None")
+
+        if not isinstance(lora_int_id, int):
+            raise ValueError(f"lora_int_id must be an integer, got {type(lora_int_id)}")
+
+        self.lora_int_id = lora_int_id
+        self.lora_name = lora_name
+
+        # Keep original lora_path as provided by user (for serialization)
+        self.lora_path = Path(lora_path)
+
+        # Resolve to local directory path (for actual weight loading)
+        self.local_adapter_path = self._resolve_adapter_path(self.lora_path)
+
+        # Load adapter config and use as defaults
+        adapter_config = self._load_adapter_config()
+
+        # Set values from adapter config if not explicitly provided
+        self.r = r if r is not None else adapter_config.get("r", 8)
+        self.lora_alpha = lora_alpha if lora_alpha is not None else adapter_config.get("lora_alpha", 8.0)
+        self.target_modules = (
+            target_modules if target_modules is not None else adapter_config.get("target_modules", None)
+        )
+        self.bias = bias if bias is not None else adapter_config.get("bias", "none")
+        if self.bias not in ["none"]:
+            raise NotImplementedError("bias != 'none' is not supported yet")
+
+        self.use_rslora = use_rslora if use_rslora is not None else adapter_config.get("use_rslora", False)
+        self.scaling_factor = scaling_factor if scaling_factor is not None else 1.0
+
+        # Validate the final values
+        if not isinstance(self.r, int) or self.r <= 0:
+            raise ValueError(f"r must be a positive integer, got {self.r}")
+
+        if self.lora_alpha <= 0:
+            raise ValueError(f"lora_alpha must be positive, got {self.lora_alpha}")
+
+        if self.bias not in ["none", "all", "lora_only"]:
+            raise ValueError(f"bias must be one of ['none', 'all', 'lora_only'], got {self.bias}")
+
+    def _resolve_adapter_path(self, path: Path) -> Path:
+        """
+        Resolve the adapter path, downloading from HuggingFace Hub if necessary.
+
+        Args:
+            path: Local path or HuggingFace Hub model ID
+
+        Returns:
+            Path object pointing to local adapter directory
+
+        Raises:
+            ValueError: If the adapter cannot be found locally or downloaded
+        """
+        # If it's a local path and exists, return it
+        if path.exists():
+            return path
+
+        # If it's an absolute path that doesn't exist, raise error
+        if path.is_absolute():
+            raise ValueError(f"LoRA adapter path does not exist: {path.as_posix()}")
+
+        # Try to interpret as HuggingFace Hub model ID and download
+        try:
+            local_dir = snapshot_download(str(path), allow_patterns=["*.safetensors", "*.bin", "*.json"])
+            return Path(local_dir)
+        except Exception as e:
+            raise ValueError(
+                f"Failed to download LoRA adapter '{path.as_posix()}' from HuggingFace Hub. "
+                f"Please check if the model ID is correct or provide a valid local path. "
+                f"Error: {e}"
+            )
+
+    def _load_adapter_config(self) -> Dict[str, Any]:
+        """
+        Load adapter configuration from adapter_config.json file.
+
+        Returns:
+            Dictionary containing adapter configuration
+
+        Raises:
+            ValueError: If adapter_config.json is not found or cannot be parsed
+        """
+        config_path = self.local_adapter_path / "adapter_config.json"
+
+        if not config_path.exists():
+            logger.warning(f"No adapter_config.json found at {config_path}, using default values")
+            return {}
+
+        try:
+            with open(config_path, "r", encoding="utf-8") as f:
+                adapter_config = json.load(f)
+            logger.info(f"Loaded adapter config from {config_path}")
+            return adapter_config
+        except Exception as e:
+            logger.warning(f"Failed to load adapter config from {config_path}: {e}, using default values")
+            return {}
+
+    def _prepare_for_serialization(self) -> Dict[str, Any]:
+        config_dict = {
+            "lora_int_id": self.lora_int_id,
+            "lora_name": self.lora_name,
+            "lora_path": str(self.lora_path),
+            "r": self.r,
+            "lora_alpha": self.lora_alpha,
+            "target_modules": self.target_modules,
+            "bias": self.bias,
+            "use_rslora": self.use_rslora,
+            "scaling_factor": self.scaling_factor,
+        }
+        return config_dict
+
+
+class RBLNLoRABaseAdapterConfig(RBLNLoRAAdapterConfig):
+    """
+    Special adapter config for the reserved base model adapter (lora_int_id = 0).
+    This adapter carries zero-effective LoRA weights by targeting no modules,
+    thereby producing no LoRA delta and yielding pure base-model behavior.
+    """
+
+    def __init__(
+        self,
+        lora_int_id: int = 0,
+        lora_name: str = "base",
+        lora_path: Union[str, Path] = "__reserved_base__",
+        r: Optional[int] = 1,
+        lora_alpha: Optional[float] = 1.0,
+        target_modules: Optional[List[str]] = None,
+        bias: Optional[str] = "none",
+        use_rslora: Optional[bool] = False,
+        scaling_factor: Optional[float] = 1.0,
+    ):
+        if lora_int_id != 0:
+            raise ValueError("RBLNLoRABaseAdapterConfig must have lora_int_id=0")
+
+        self.lora_int_id = 0
+        self.lora_name = lora_name
+        # Keep original lora_path for serialization purposes but do not resolve it.
+        self.lora_path = Path(str(lora_path))
+        self.local_adapter_path = None
+
+        # Set minimal defaults; target_modules empty disables LoRA on all projections
+        self.r = 1 if r is None else r
+        self.lora_alpha = 1.0 if lora_alpha is None else lora_alpha
+        self.target_modules = []
+        self.bias = "none"
+        self.use_rslora = False
+        self.scaling_factor = 1.0
+
+        # Validate minimal settings
+        if not isinstance(self.r, int) or self.r <= 0:
+            raise ValueError(f"r must be a positive integer, got {self.r}")
+        if self.lora_alpha <= 0:
+            raise ValueError(f"lora_alpha must be positive, got {self.lora_alpha}")
+
+
+class RBLNLoRAConfig(RBLNSerializableConfigProtocol):
+    """
+    Configuration class for multi-LoRA support in RBLN decoder-only models.
+
+    This class manages all LoRA adapters that will be compiled into the RBLN model.
+    Since RBLN NPU requires all adapters to be determined at compile time, this
+    configuration must specify all adapters upfront with their weights.
+
+    Key constraints for RBLN multi-LoRA:
+    1. All LoRA adapters must be specified at compile time
+    2. Adapter weights must be available during compilation
+    3. The number of adapters is fixed after compilation
+    4. Runtime can only switch between pre-compiled adapters
+    """
+
+    def __init__(
+        self, adapters: List[Union[Dict[str, Any], RBLNLoRAAdapterConfig]], max_lora_rank: Optional[int] = None
+    ):
+        """
+        Args:
+            adapters (List[Union[Dict[str, Any], RBLNLoRAAdapterConfig]]): List of LoRA adapters
+                to be compiled into the model. Each adapter must be fully specified with weights
+                accessible at compile time.
+            max_lora_rank (Optional[int]): Maximum rank across all adapters. If None, automatically
+                determined from the provided adapters. Used for memory allocation optimization.
+
+        Raises:
+            ValueError: If adapters list is empty.
+            ValueError: If adapter IDs are not unique.
+            ValueError: If any adapter path doesn't exist.
+        """
+        if not adapters:
+            raise ValueError("adapters list cannot be empty")
+
+        # Convert dict adapters to RBLNLoRAAdapterConfig objects
+        self.adapters: List[RBLNLoRAAdapterConfig] = []
+        for adapter in adapters:
+            if isinstance(adapter, dict):
+                self.adapters.append(RBLNLoRAAdapterConfig(**adapter))
+            elif isinstance(adapter, RBLNLoRAAdapterConfig):
+                self.adapters.append(adapter)
+            else:
+                raise ValueError(f"Invalid adapter type: {type(adapter)}")
+
+        # Disallow user-provided adapter with id 0: it's reserved for base model
+        if any(ad.lora_int_id == 0 for ad in self.adapters):
+            raise ValueError(
+                "lora_int_id=0 is reserved for base model and cannot be provided. Please renumber your adapters to start from 1."
+            )
+
+        # Inject a reserved zero-weight adapter for base model at id=0
+        base_adapter = RBLNLoRABaseAdapterConfig()
+        self.adapters.insert(0, base_adapter)
+
+        # Sort adapters by ID to make IDs align with indices
+        self.adapters.sort(key=lambda a: a.lora_int_id)
+
+        # Validate unique and contiguous adapter IDs starting from 0
+        adapter_ids = [adapter.lora_int_id for adapter in self.adapters]
+        if len(adapter_ids) != len(set(adapter_ids)):
+            raise ValueError("All adapter IDs must be unique")
+        expected_ids = list(range(len(self.adapters)))
+        if adapter_ids != expected_ids:
+            raise ValueError(
+                f"Adapter IDs must be contiguous and start from 0. Found {adapter_ids}, expected {expected_ids}."
+            )
+
+        # Calculate max_lora_rank if not provided
+        if max_lora_rank is None:
+            self.max_lora_rank = max(adapter.r for adapter in self.adapters)
+        else:
+            self.max_lora_rank = max_lora_rank
+            # Validate that max_lora_rank is sufficient
+            actual_max_rank = max(adapter.r for adapter in self.adapters)
+            if self.max_lora_rank < actual_max_rank:
+                raise ValueError(
+                    f"max_lora_rank ({self.max_lora_rank}) must be >= actual max rank ({actual_max_rank})"
+                )
+
+    @property
+    def num_adapters(self) -> int:
+        return len(self.adapters)
+
+    @property
+    def adapter_ids(self) -> List[int]:
+        return [adapter.lora_int_id for adapter in self.adapters]
+
+    @property
+    def adapter_names(self) -> List[str]:
+        return [adapter.lora_name for adapter in self.adapters]
+
+    def get_adapter_by_id(self, lora_int_id: int) -> Optional[RBLNLoRAAdapterConfig]:
+        for adapter in self.adapters:
+            if adapter.lora_int_id == lora_int_id:
+                return adapter
+        return None
+
+    def get_adapter_by_name(self, lora_name: str) -> Optional[RBLNLoRAAdapterConfig]:
+        for adapter in self.adapters:
+            if adapter.lora_name == lora_name:
+                return adapter
+        return None
+
+    def validate_adapter_weights(self) -> Dict[int, bool]:
+        validation_results = {}
+        for adapter in self.adapters:
+            try:
+                # The reserved base adapter (id=0) always validates to True
+                if adapter.lora_int_id == 0:
+                    validation_results[adapter.lora_int_id] = True
+                    continue
+                # Check if adapter path exists and contains expected files
+                adapter_path = adapter.local_adapter_path
+                if adapter_path is not None and adapter_path.is_file():
+                    # Single file adapter (e.g., safetensors)
+                    validation_results[adapter.lora_int_id] = adapter_path.exists()
+                else:
+                    # Directory adapter - check for common LoRA files
+                    expected_files = ["adapter_model.safetensors", "adapter_config.json"]
+                    alternative_files = ["pytorch_model.bin", "adapter_model.bin"]
+
+                    has_weights = adapter_path is not None and any(
+                        (adapter_path / f).exists() for f in expected_files + alternative_files
+                    )
+                    has_config = adapter_path is not None and (adapter_path / "adapter_config.json").exists()
+
+                    validation_results[adapter.lora_int_id] = has_weights and has_config
+            except Exception as e:
+                logger.warning(f"Failed to validate adapter {adapter.lora_int_id}: {e}")
+                validation_results[adapter.lora_int_id] = False
+
+        return validation_results
+
+    def _prepare_for_serialization(self) -> Dict[str, Any]:
+        # Do not serialize the reserved base adapter (id=0)
+        serializable_adapters = [adapter for adapter in self.adapters if adapter.lora_int_id != 0]
+        serializable_map = {
+            "adapters": [adapter._prepare_for_serialization() for adapter in serializable_adapters],
+            "max_lora_rank": self.max_lora_rank,
+        }
+        return serializable_map
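Before moving on to the architecture changes, here is a minimal sketch of how the helpers above behave, assuming a hypothetical local adapter directory ./adapters/math that contains adapter_config.json and adapter weights:

```python
# Minimal sketch of RBLNLoRAConfig behavior; "./adapters/math" is a hypothetical local
# directory holding adapter_config.json plus adapter_model.safetensors (or .bin).
from optimum.rbln import RBLNLoRAConfig

lora_config = RBLNLoRAConfig(
    adapters=[
        # Dicts are converted to RBLNLoRAAdapterConfig; r and lora_alpha fall back to the
        # values found in the adapter's adapter_config.json when omitted.
        {"lora_int_id": 1, "lora_name": "math", "lora_path": "./adapters/math"},
    ],
)

# A zero-weight base adapter is injected at id 0, so ids and names include the base model.
print(lora_config.adapter_ids)    # [0, 1]
print(lora_config.adapter_names)  # ["base", "math"]
print(lora_config.max_lora_rank)  # max rank across compiled adapters

# Per-adapter check that weights and adapter_config.json are present on disk.
print(lora_config.validate_adapter_weights())  # {0: True, 1: True or False}
```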
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (+100 -20)

@@ -22,6 +22,8 @@ from transformers import PretrainedConfig, PreTrainedModel
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig
+from .lora_architecture import LoRALinear
 
 
 if TYPE_CHECKING:
@@ -52,12 +54,7 @@ class DecoderOnlyWrapper(nn.Module):
 
     _use_learned_pos_emb = False
 
-    def __init__(
-        self,
-        model: PreTrainedModel,
-        rbln_config: "RBLNDecoderOnlyModelConfig",
-        use_rotary_emb: bool,
-    ):
+    def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
         super().__init__()
         self.quantization = rbln_config.quantization
         self.config = model.config
@@ -114,7 +111,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_self_attn = self.get_rbln_attn_class()(
                 self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
             )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
             new_layers.append(new_layer)
 
         new_model = self.get_rbln_model_class()(
@@ -154,6 +151,7 @@ class DecoderOnlyWrapper(nn.Module):
         )
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
         past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -185,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         )
@@ -199,6 +198,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         ) = self.prepare_forward_args(*args)
@@ -214,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )
 
         return logit
@@ -270,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         rotary_emb: nn.Module = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # outputs
         hidden_states = self.model(
@@ -283,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )
 
         if "prefill" in self.phase:
@@ -394,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -466,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
                 cos=cos,
                 sin=sin,
                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
@@ -497,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
         phase: Current operation phase ("prefill" or "decode")
     """
 
-    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
         self._original_mod = layer
         self.self_attn = self_attn
         self._phase = "prefill"
+        self.lora_config = lora_config
+
+        # Replace target Linear modules in MLP with LoRALinear if configured
+        if self.lora_config:
+            mlp = self.get_mlp()
+            for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+                if hasattr(mlp, proj_name):
+                    original_linear = getattr(mlp, proj_name)
+                    if isinstance(original_linear, nn.Linear):
+                        lora_linear = LoRALinear(
+                            original_linear=original_linear,
+                            lora_config=self.lora_config,
+                            projection_name=proj_name,
+                            layer_idx=self.self_attn.layer_idx,
+                        )
+                        setattr(mlp, proj_name, lora_linear)
 
     @property
     def phase(self):
@@ -518,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.post_attention_layernorm
 
+    def get_mlp(self) -> nn.Module:
+        return self._original_mod.mlp
+
+    def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        mlp = self.get_mlp()
+        if self.lora_config and lora_int_id is not None:
+            gate = mlp.gate_proj(hidden_states, lora_int_id)
+            up = mlp.up_proj(hidden_states, lora_int_id)
+            act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+            if act_fn is None:
+                gate = torch.nn.functional.silu(gate)
+            else:
+                gate = act_fn(gate)
+            fused = gate * up
+            hidden_states = mlp.down_proj(fused, lora_int_id)
+        else:
+            hidden_states = mlp(hidden_states)
+        return hidden_states
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -527,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -539,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
             cos=cos,
             sin=sin,
             block_tables=block_tables,
+            lora_int_id=lora_int_id,
         )
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.get_post_attention_layernorm()(hidden_states)
-        hidden_states = self.
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
         hidden_states = residual + hidden_states
 
         return hidden_states
@@ -595,10 +637,23 @@ class DecoderOnlyAttention(nn.Module):
         self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
         self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
         self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+        self.lora_config = rbln_config.lora_config
 
         setattr(self, self.get_attention_name(), self.create_attention_op())
         self.__post_init__()
 
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+        for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            original_linear = getattr(self._original_mod, proj_name)
+            lora_linear = LoRALinear(
+                original_linear=original_linear,
+                lora_config=self.lora_config,
+                projection_name=proj_name,
+                layer_idx=self.layer_idx,
+            )
+            setattr(self, proj_name, lora_linear)
+
     def get_attention_name(self):
         if self.is_sliding:
             return "sliding_window_attention"
@@ -651,23 +706,40 @@ class DecoderOnlyAttention(nn.Module):
         raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
     def __post_init__(self):
-
-
-
-
-
-
+        # Initialize LoRA weights if configured, which will replace linear layers
+        if self.lora_config:
+            self._init_lora_weights()
+        else:
+            # Use original linear layers if no LoRA
+            self.q_proj = self._original_mod.q_proj
+            self.k_proj = self._original_mod.k_proj
+            self.v_proj = self._original_mod.v_proj
+            self.o_proj = self._original_mod.o_proj
+
+    def projection(
+        self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Projects input hidden states into query, key, and value representations.
 
         Args:
            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+           lora_int_id: Adapter ID tensor for LoRA selection [batch_size]
 
        Returns:
            Tuple of (query_states, key_states, value_states)
        """
-
-
-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            query_states = self.q_proj(hidden_states, lora_int_id)
+            key_states = self.k_proj(hidden_states, lora_int_id)
+            value_states = self.v_proj(hidden_states, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
 
         return query_states, key_states, value_states
 
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
@@ -695,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()
 
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)
 
         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -732,7 +805,14 @@ class DecoderOnlyAttention(nn.Module):
             v_scale=v_scale,
         )
 
-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            attn_outputs = self.o_proj(attn_output, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            attn_outputs = self.o_proj(attn_output)
+
         return attn_outputs
 
 
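The LoRALinear module referenced throughout these hunks comes from the new lora_architecture.py (+204 lines), which is not included in this excerpt. As rough orientation only, the sketch below shows one way a module could satisfy the call sites above (constructor kwargs original_linear, lora_config, projection_name, layer_idx, and a forward that takes hidden_states plus lora_int_id) by stacking per-adapter low-rank weights and indexing them with lora_int_id. Per-adapter scaling (lora_alpha / r, rslora, scaling_factor) and actual weight loading from the adapter checkpoints are omitted; this is not the real implementation.

```python
# Sketch only: illustrates the multi-adapter LoRA linear call pattern used above.
from typing import Optional

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    def __init__(self, original_linear: nn.Linear, lora_config, projection_name: str, layer_idx: int):
        super().__init__()
        self.base = original_linear
        r = lora_config.max_lora_rank
        n = lora_config.num_adapters
        in_f, out_f = original_linear.in_features, original_linear.out_features
        # One (A, B) pair per compiled adapter; index 0 stays zero, so the reserved
        # base adapter adds no delta and reproduces pure base-model behavior.
        self.lora_a = nn.Parameter(torch.zeros(n, in_f, r))
        self.lora_b = nn.Parameter(torch.zeros(n, r, out_f))

    def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        y = self.base(x)
        if lora_int_id is None:
            return y
        # Select each batch element's adapter weights and add its low-rank delta.
        a = self.lora_a[lora_int_id]  # [batch, in_features, max_rank]
        b = self.lora_b[lora_int_id]  # [batch, max_rank, out_features]
        return y + torch.bmm(torch.bmm(x, a), b)
```

Because every adapter is materialized at compile time, runtime adapter switching reduces to passing a different lora_int_id tensor through the compiled graph, which is consistent with the "runtime can only switch between pre-compiled adapters" constraint stated in RBLNLoRAConfig.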