optimum-rbln 0.9.2a3__py3-none-any.whl → 0.9.2a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +4 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +3 -0
- optimum/rbln/modeling.py +71 -1
- optimum/rbln/transformers/__init__.py +4 -0
- optimum/rbln/transformers/modeling_generic.py +23 -1
- optimum/rbln/transformers/models/__init__.py +4 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +65 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +79 -4
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +9 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +2 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +0 -9
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/METADATA +5 -5
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/RECORD +34 -32
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/licenses/LICENSE +0 -0
@@ -187,6 +187,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
         )
 
+        self.lora_int_ids = None
+
     def inputs_embeddings_if_needed(
         self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
     ):
@@ -210,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ):
         inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
         block_tables, local_block_tables, is_external_block_tables = (
@@ -233,6 +236,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 position_embed=position_embed,
                 position_ids=position_ids,
                 local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
             )
         else:
             return self.prefill_forward(
@@ -245,6 +249,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 position_embed=position_embed,
                 token_type_ids=token_type_ids,
                 local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
             )
 
     def decode_forward(
@@ -257,7 +262,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            lora_int_ids = self.lora_int_ids
+
+        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
         if self.batch_size != cache_position.shape[0]:
             raise RuntimeError(
                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
@@ -287,6 +305,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed,
             attention_mask if self.rbln_config.use_attention_mask else None,
             position_ids if self.rbln_config.use_position_ids else None,
+            lora_int_ids if self.rbln_config.use_lora else None,
         )
 
         return RBLNDecoderOnlyOutput(logits=logits)
@@ -369,12 +388,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            if batch_idx is not None:
+                lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+            else:
+                lora_int_ids = self.lora_int_ids.clone()
+
         (
             inputs,
             cache_position,
@@ -426,6 +458,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 query_position,
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
+                lora_int_ids if self.rbln_config.use_lora else None,
                 out=self.out_buffers,
             )
             output_logits.append(output_logit)
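The hunks above (which line up with the `decoderonly_runtime_utils.py` entry in the changed-file list) thread a per-batch `lora_int_ids` tensor through the runtime's decode and prefill paths: when the caller does not pass one, the runtime falls back to the value stored by `set_lora_int_ids()`, validates its batch dimension, and forwards it to the compiled graph only when `use_lora` is enabled. A minimal sketch of that fallback-and-validate pattern, using a hypothetical helper name rather than the package's actual code:

from typing import Optional

import torch


def resolve_lora_int_ids(
    use_lora: bool,
    passed: Optional[torch.Tensor],
    stored: Optional[torch.Tensor],
    batch_size: int,
) -> Optional[torch.Tensor]:
    # Hypothetical helper mirroring the fallback/validation added in this release.
    if use_lora and passed is None:
        if stored is None:
            raise ValueError("lora_int_ids must be set via set_lora_int_ids() or passed to forward().")
        passed = stored
    if passed is not None and passed.shape[0] != batch_size:
        raise ValueError(f"lora_int_ids size mismatch: got {passed.shape[0]}, expected {batch_size}.")
    return passed


# Example: two sequences in the batch, each mapped to a different compiled adapter slot.
ids = resolve_lora_int_ids(use_lora=True, passed=None, stored=torch.tensor([1, 0], dtype=torch.int32), batch_size=2)
print(ids)  # tensor([1, 0], dtype=torch.int32)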
@@ -0,0 +1,204 @@
+import math
+from pathlib import Path
+from typing import Optional
+
+import safetensors.torch
+import torch
+from torch import nn
+
+from ....utils import logging
+from .configuration_lora import RBLNLoRAConfig
+
+
+logger = logging.get_logger()
+
+
+class LoRALinear(nn.Module):
+    """
+    A linear layer that supports multiple LoRA adapters compiled at static time.
+
+    This class replaces the original linear layer and handles both base weights
+    and multiple LoRA adapters in a single forward pass using custom ops.
+    """
+
+    def __init__(
+        self,
+        original_linear: nn.Linear,
+        lora_config: RBLNLoRAConfig,
+        projection_name: str = "",
+        layer_idx: int = 0,
+    ):
+        """
+        Args:
+            original_linear: The original linear layer to be replaced
+            lora_config: LoRA configuration containing all adapters
+            projection_name: Name of the projection (e.g., "q_proj", "k_proj")
+            layer_idx: Layer index for loading the correct LoRA weights
+        """
+        super().__init__()
+
+        self.in_features = original_linear.in_features
+        self.out_features = original_linear.out_features
+        self.projection_name = projection_name
+        self.layer_idx = layer_idx
+        self.lora_config = lora_config
+
+        # Store original linear weights and bias directly without cloning
+        self.register_buffer("weight", original_linear.weight.data)
+        if original_linear.bias is not None:
+            self.register_buffer("bias", original_linear.bias.data)
+        else:
+            self.bias = None
+
+        # Initialize LoRA weights
+        self._init_lora_weights()
+
+    def _should_apply_lora(self) -> bool:
+        """Check if this projection should have LoRA applied."""
+        # Check if any adapter targets this projection
+        return any(self.projection_name in adapter.target_modules for adapter in self.lora_config.adapters)
+
+    def _load_adapter_weights(self, adapter_path: Path):
+        """
+        Load adapter weights from local directory.
+
+        Args:
+            adapter_path: Path to local directory containing adapter weights
+
+        Returns:
+            Dictionary containing adapter weights
+
+        Raises:
+            FileNotFoundError: If no adapter weights are found in the directory
+        """
+        if not adapter_path.is_dir():
+            raise ValueError(f"Adapter path must be a directory, got: {adapter_path}")
+
+        # Try to load weights in order of preference
+        weight_files = [
+            ("adapter_model.safetensors", lambda p: safetensors.torch.load_file(p)),
+            ("adapter_model.bin", lambda p: torch.load(p, map_location="cpu")),
+            ("pytorch_model.bin", lambda p: torch.load(p, map_location="cpu")),
+        ]
+
+        for filename, load_fn in weight_files:
+            weight_path = adapter_path / filename
+            if weight_path.exists():
+                return load_fn(weight_path)
+
+        raise FileNotFoundError(
+            f"No adapter weights found in {adapter_path}. "
+            f"Expected one of: {', '.join(filename for filename, _ in weight_files)}"
+        )
+
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by loading and stacking them."""
+
+        lora_a_weights = []
+        lora_b_weights = []
+
+        for adapter in self.lora_config.adapters:
+            if self.projection_name not in adapter.target_modules:
+                # Create zero weights for adapters that don't target this projection
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+                continue
+
+            adapter_weights = self._load_adapter_weights(adapter.local_adapter_path)
+
+            # Determine module type from projection name
+            attn_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
+            mlp_projs = {"gate_proj", "up_proj", "down_proj"}
+            if self.projection_name in attn_projs:
+                module_type = "self_attn"
+            elif self.projection_name in mlp_projs:
+                module_type = "mlp"
+            else:
+                module_type = "self_attn"
+
+            layer_key = f"base_model.model.model.layers.{self.layer_idx}.{module_type}.{self.projection_name}"
+            lora_a_key = f"{layer_key}.lora_A.weight"
+            lora_b_key = f"{layer_key}.lora_B.weight"
+
+            if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
+                # Calculate scaling factor and fold it into lora_b
+                scaling = adapter.lora_alpha / adapter.r
+                if adapter.use_rslora:
+                    scaling = scaling / math.sqrt(adapter.r)
+                scaling = scaling * adapter.scaling_factor
+
+                lora_a_weights.append(adapter_weights[lora_a_key])
+                # scaling is pre-applied to lora_b_weights
+                lora_b_weights.append(adapter_weights[lora_b_key] * scaling)
+            else:
+                logger.warning(f"No LoRA weights found for {lora_a_key} or {lora_b_key}")
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+
+        # Stack weights along adapter dimension
+        max_rank = self.lora_config.max_lora_rank
+
+        # Pad smaller ranks to max_rank
+        padded_lora_a = []
+        padded_lora_b = []
+
+        for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+            current_rank = lora_a.shape[0]
+            if current_rank < max_rank:
+                # Pad with zeros
+                padded_a = torch.zeros(max_rank, self.in_features)
+                padded_b = torch.zeros(self.out_features, max_rank)
+                padded_a[:current_rank] = lora_a
+                padded_b[:, :current_rank] = lora_b
+                padded_lora_a.append(padded_a)
+                padded_lora_b.append(padded_b)
+            else:
+                padded_lora_a.append(lora_a)
+                padded_lora_b.append(lora_b)
+
+        lora_a_transposed = [lora_a.transpose(0, 1) for lora_a in padded_lora_a]  # [in_features, rank]
+        lora_b_transposed = [lora_b.transpose(0, 1) for lora_b in padded_lora_b]  # [rank, out_features]
+
+        self.register_buffer(
+            "lora_a_weights", torch.stack(lora_a_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, in_features, rank]
+        self.register_buffer(
+            "lora_b_weights", torch.stack(lora_b_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, rank, out_features]
+
+    def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Forward pass that combines base linear transformation with LoRA.
+
+        Args:
+            x: Input tensor [batch_size, seq_len, in_features]
+            lora_int_id: Adapter ID tensor [batch_size] indicating which adapter to use
+
+        Returns:
+            Output tensor [batch_size, seq_len, out_features]
+        """
+        # Base linear transformation
+        output = torch.nn.functional.linear(x, self.weight, self.bias)
+
+        # Apply LoRA if enabled and adapter ID is provided
+        if self._should_apply_lora() and lora_int_id is not None:
+            # Gather LoRA weights for each batch item
+            # lora_int_id: [batch_size] -> use as indices to select weights
+            selected_lora_a = self.lora_a_weights[lora_int_id]  # [batch_size, in_features, rank]
+            selected_lora_b = self.lora_b_weights[lora_int_id]  # [batch_size, rank, out_features]
+
+            # Batched matrix multiplication for LoRA computation
+            # x: [batch_size, seq_len, in_features]
+            # selected_lora_a: [batch_size, in_features, rank] (already transposed)
+            # selected_lora_b: [batch_size, rank, out_features] (already transposed)
+
+            # First matmul: x @ lora_a -> [batch_size, seq_len, rank]
+            temp = torch.bmm(x, selected_lora_a)
+
+            # Second matmul: temp @ lora_b -> [batch_size, seq_len, out_features]
+            lora_delta = torch.bmm(temp, selected_lora_b)
+
+            # Add LoRA delta to base output
+            output = output + lora_delta
+
+        return output
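This new module (matching `lora_architecture.py`, the 204-line addition in the file list) loads every adapter's A/B matrices up front, pads them to `max_lora_rank`, folds the `lora_alpha / r` (and optional rsLoRA) scaling into B, and stacks the results so a per-sequence adapter can be applied at run time with one gather and two batched matmuls. The shape flow of that forward path can be checked in plain PyTorch; the sketch below uses made-up dimensions and random weights and is not the class itself:

import torch

batch, seq, in_f, out_f, rank, num_adapters = 2, 4, 16, 8, 4, 3

x = torch.randn(batch, seq, in_f)
weight = torch.randn(out_f, in_f)               # base linear weight
lora_a = torch.randn(num_adapters, in_f, rank)  # stacked, already transposed
lora_b = torch.randn(num_adapters, rank, out_f) # scaling assumed pre-applied here

lora_int_id = torch.tensor([2, 0])              # adapter index per batch item

base = torch.nn.functional.linear(x, weight)    # [batch, seq, out_f]
delta = torch.bmm(torch.bmm(x, lora_a[lora_int_id]), lora_b[lora_int_id])
out = base + delta
print(out.shape)  # torch.Size([2, 4, 8])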
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import no_init_weights
 
@@ -317,12 +317,27 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
     @classmethod
     def get_pytorch_model(
-        cls,
+        cls,
+        model_id: str,
+        *args,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+        num_hidden_layers: Optional[int] = None,
+        **kwargs,
     ) -> PreTrainedModel:
         if rbln_config and rbln_config.quantization:
-            model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
+            model = cls.get_quantized_model(model_id, *args, rbln_config=rbln_config, **kwargs)
         else:
-
+            if num_hidden_layers is not None:
+                trust_remote_code = kwargs.get("trust_remote_code", None)
+                config, kwargs = AutoConfig.from_pretrained(
+                    model_id, return_unused_kwargs=True, num_hidden_layers=num_hidden_layers, **kwargs
+                )
+                if hasattr(config, "layer_types"):
+                    config.layer_types = config.layer_types[:num_hidden_layers]
+                kwargs["config"] = config
+                kwargs["trust_remote_code"] = trust_remote_code
+
+            model = super().get_pytorch_model(model_id, *args, **kwargs)
 
         return model
 
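`get_pytorch_model` now accepts an explicit `model_id` plus an optional `num_hidden_layers` override; when the override is given, the config is loaded first through `AutoConfig.from_pretrained(..., return_unused_kwargs=True)` so the layer count (and `layer_types`, if the config has one) can be truncated before weights are instantiated. A rough standalone equivalent, with a placeholder checkpoint id:

from transformers import AutoConfig, AutoModelForCausalLM

model_id = "some-org/some-causal-lm"  # placeholder checkpoint id
num_hidden_layers = 2                 # e.g. compile a truncated model for a quick smoke test

config, unused_kwargs = AutoConfig.from_pretrained(
    model_id, return_unused_kwargs=True, num_hidden_layers=num_hidden_layers
)
if hasattr(config, "layer_types"):    # hybrid-attention configs (e.g. Gemma 3) carry per-layer types
    config.layer_types = config.layer_types[:num_hidden_layers]

model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
print(model.config.num_hidden_layers)  # 2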
@@ -375,6 +390,9 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
+        if rbln_config.use_lora:
+            input_info.append(("lora_int_ids", [batch_size], "int32"))
+
         kvcache_dtype = rbln_config.torch_dtype
         if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
             kvcache_dtype = "float8_e4m3fn"
@@ -667,6 +685,53 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
         return is_prefill
 
+    def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+        if isinstance(lora_int_ids, int):
+            lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+        elif isinstance(lora_int_ids, list):
+            lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
+
+        self.lora_int_ids = lora_int_ids
+
+        self.prefill_decoder.lora_int_ids = lora_int_ids
+        if self.rbln_config.can_generate:
+            for batch_size in self.rbln_config.decoder_batch_sizes:
+                self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets the active adapter(s) for the model using adapter name(s).
+
+        Args:
+            adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                Can be a single adapter name or a list of adapter names.
+
+        Raises:
+            ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+        """
+        if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+            raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+        # Convert single adapter name to list for uniform processing
+        if isinstance(adapter_name, str):
+            adapter_names = [adapter_name]
+        else:
+            adapter_names = adapter_name
+
+        # Validate that all adapter names exist
+        available_adapters = {
+            adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+        }
+        missing_adapters = [name for name in adapter_names if name not in available_adapters]
+        if missing_adapters:
+            raise ValueError(
+                f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+            )
+
+        # Get the adapter IDs and set them
+        lora_int_ids = [available_adapters[name] for name in adapter_names]
+        self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
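`set_adapter()` maps adapter names to the integer slots (`lora_int_id`) recorded in the compiled LoRA config and pushes them into the prefill and decode runtimes via `set_lora_int_ids()`. Usage would look roughly like the following; the model class, checkpoint path, adapter names, and loading arguments are illustrative and should be checked against the optimum-rbln documentation for this release:

from optimum.rbln import RBLNLlamaForCausalLM

# Assumed: the model was compiled with a LoRA config that registered
# adapters named "ko" and "math" (illustrative names).
model = RBLNLlamaForCausalLM.from_pretrained("path/to/compiled-llama-with-lora", export=False)

# Activate one adapter by name for every sequence in the batch...
model.set_adapter("ko")

# ...or pick a compiled adapter slot per sequence by integer id.
model.set_lora_int_ids([0, 1])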
@@ -677,6 +742,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -684,6 +750,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
         # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
         # A for-loop ensures synchronization with the HuggingFace generate API.
         # The decoder stage operates as usual, processing inputs in batch mode.
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            lora_int_ids = self.lora_int_ids
 
         # for only use forward
         if generate_idx is None:
@@ -708,6 +781,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -727,6 +801,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                lora_int_ids=lora_int_ids,
             ).logits
 
         if not return_dict:
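Note how the prefill loop above slices the batch-wide tensor down to a single element per request (`lora_int_ids[b_idx : b_idx + 1]`) while the decode call consumes the full batch at once. A toy illustration of that slicing, independent of the RBLN runtimes:

import torch

lora_int_ids = torch.tensor([2, 0, 1], dtype=torch.int32)  # one adapter slot per sequence

# Prefill: one sequence at a time, keeping the tensor 1-D with shape [1].
for b_idx in range(lora_int_ids.shape[0]):
    per_request = lora_int_ids[b_idx : b_idx + 1]
    print(b_idx, per_request)  # e.g. 0 tensor([2], dtype=torch.int32)

# Decode: the whole batch in a single step.
print(lora_int_ids.shape)  # torch.Size([3])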
@@ -63,6 +63,7 @@ class Gemma3TextModel(DecoderOnlyModel):
         rotary_emb: torch.nn.Module = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -105,6 +106,7 @@ class Gemma3TextModel(DecoderOnlyModel):
                 cos=cos_local if is_sliding else cos_global,
                 sin=sin_local if is_sliding else sin_global,
                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
@@ -127,12 +129,20 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
         hidden_states = self.self_attn(
-            hidden_states,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            seq_positions=seq_positions,
+            past_key_values=past_key_values,
+            cos=cos,
+            sin=sin,
+            block_tables=block_tables,
+            lora_int_id=lora_int_id,
         )
         hidden_states = self.get_post_attention_layernorm()(hidden_states)
         hidden_states = residual + hidden_states
@@ -140,7 +150,7 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
-        hidden_states = self.
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
         hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
         hidden_states = residual + hidden_states
 
@@ -17,15 +17,16 @@ import rebel
 import torch
 
 from ...modeling_outputs import RBLNDecoderOnlyOutput, RBLNGemma3ForCausalLMOutput
+from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
 from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
 
 
 class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
     def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
         super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
+        self.image_prefill = RBLNPytorchRuntime(image_prefill)  # FIXME(taehoon)
+        self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None  # FIXME
+        self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None
 
     def _prepare_prefill_inputs(self, *args, **kwargs):
         (
@@ -73,12 +74,24 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            if batch_idx is not None:
+                lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+            else:
+                lora_int_ids = self.lora_int_ids.clone()
+
         (
             inputs,
             cache_position,
@@ -141,6 +154,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     query_position,
                     chunked_attention_mask,
                     position_ids_chunk,
+                    lora_int_ids if self.rbln_config.use_lora else None,
                 )
             else:
                 logits = self.prefill(
@@ -151,6 +165,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     query_position,
                     chunked_attention_mask,
                     position_ids_chunk,
+                    lora_int_ids if self.rbln_config.use_lora else None,
                 )
 
             padded_cache_lengths += current_padded_cache_lengths
@@ -173,7 +188,20 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            lora_int_ids = self.lora_int_ids
+
+        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
             raise RuntimeError(
@@ -28,6 +28,7 @@ from ....modeling import RBLNModel
 from ...modeling_outputs import RBLNDecoderOnlyOutput
 from ...utils.rbln_runtime_wrapper import LoopProcessor
 from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 from ..decoderonly.modeling_decoderonly import (
     RBLNDecoderOnlyModelForCausalLM,
 )
@@ -77,7 +78,7 @@ class LoopProjector(LoopProcessor):
         return output[0]
 
 
-class RBLNGemma3ForConditionalGeneration(RBLNModel):
+class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForImageTextToText
     _rbln_submodules = [
         {"name": "vision_tower"},
@@ -408,6 +409,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
         sliding_window = getattr(model_config, "sliding_window", None)
         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern is None:
+            if hasattr(model_config, "layer_types"):
+                first_full_attention_index = model_config.layer_types.index("full_attention")
+                sliding_window_pattern = first_full_attention_index + 1
+            else:
+                raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
         if sliding_window_pattern <= model_config.num_hidden_layers:
             rbln_config.cache_impl = "hybrid"
             rbln_config.sliding_window = sliding_window
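Recent Gemma 3 configs may omit `sliding_window_pattern` and expose `layer_types` instead; the fallback added here recovers the pattern as the index of the first `"full_attention"` entry plus one. A quick check of that arithmetic on a hypothetical layer layout:

# Hypothetical layout: five sliding-window layers followed by one full-attention layer.
layer_types = ["sliding_attention"] * 5 + ["full_attention"]

sliding_window_pattern = layer_types.index("full_attention") + 1
print(sliding_window_pattern)  # 6, i.e. every 6th layer uses global attention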
@@ -75,7 +75,10 @@ class GPT2Attention(DecoderOnlyAttention):
         self.o_proj = self._original_mod.c_proj
         self.split_size = self._original_mod.split_size
 
-    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if lora_int_id is not None:
+            raise NotImplementedError("LoRA is not supported for GPT2Attention")
+
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states
 
@@ -35,6 +35,7 @@ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 if TYPE_CHECKING:
@@ -120,9 +121,6 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             attention_mask=patch_attention_mask,
-            output_attentions=None,
-            output_hidden_states=None,
-            return_dict=False,
         )
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)
@@ -185,7 +183,7 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         return BaseModelOutput(last_hidden_state=last_hidden_state)
 
 
-class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -27,6 +27,7 @@ from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ...modeling_outputs import RBLNDecoderOnlyOutput
 from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = get_logger(__name__)
@@ -103,7 +104,7 @@ class LoopProjector(LoopProcessor):
         return output[0]
 
 
-class RBLNLlavaForConditionalGeneration(RBLNModel):
+class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -32,6 +32,7 @@ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
 
 
@@ -87,7 +88,7 @@ class LoopProjector(LoopProcessor):
         return output[0]
 
 
-class RBLNLlavaNextForConditionalGeneration(RBLNModel):
+class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNLlavaNextForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
|