optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/lora_architecture.py
@@ -0,0 +1,204 @@
+import math
+from pathlib import Path
+from typing import Optional
+
+import safetensors.torch
+import torch
+from torch import nn
+
+from ....utils import logging
+from .configuration_lora import RBLNLoRAConfig
+
+
+logger = logging.get_logger()
+
+
+class LoRALinear(nn.Module):
+    """
+    A linear layer that supports multiple LoRA adapters compiled at static time.
+
+    This class replaces the original linear layer and handles both base weights
+    and multiple LoRA adapters in a single forward pass using custom ops.
+    """
+
+    def __init__(
+        self,
+        original_linear: nn.Linear,
+        lora_config: RBLNLoRAConfig,
+        projection_name: str = "",
+        layer_idx: int = 0,
+    ):
+        """
+        Args:
+            original_linear: The original linear layer to be replaced
+            lora_config: LoRA configuration containing all adapters
+            projection_name: Name of the projection (e.g., "q_proj", "k_proj")
+            layer_idx: Layer index for loading the correct LoRA weights
+        """
+        super().__init__()
+
+        self.in_features = original_linear.in_features
+        self.out_features = original_linear.out_features
+        self.projection_name = projection_name
+        self.layer_idx = layer_idx
+        self.lora_config = lora_config
+
+        # Store original linear weights and bias directly without cloning
+        self.register_buffer("weight", original_linear.weight.data)
+        if original_linear.bias is not None:
+            self.register_buffer("bias", original_linear.bias.data)
+        else:
+            self.bias = None
+
+        # Initialize LoRA weights
+        self._init_lora_weights()
+
+    def _should_apply_lora(self) -> bool:
+        """Check if this projection should have LoRA applied."""
+        # Check if any adapter targets this projection
+        return any(self.projection_name in adapter.target_modules for adapter in self.lora_config.adapters)
+
+    def _load_adapter_weights(self, adapter_path: Path):
+        """
+        Load adapter weights from local directory.
+
+        Args:
+            adapter_path: Path to local directory containing adapter weights
+
+        Returns:
+            Dictionary containing adapter weights
+
+        Raises:
+            FileNotFoundError: If no adapter weights are found in the directory
+        """
+        if not adapter_path.is_dir():
+            raise ValueError(f"Adapter path must be a directory, got: {adapter_path}")
+
+        # Try to load weights in order of preference
+        weight_files = [
+            ("adapter_model.safetensors", lambda p: safetensors.torch.load_file(p)),
+            ("adapter_model.bin", lambda p: torch.load(p, map_location="cpu")),
+            ("pytorch_model.bin", lambda p: torch.load(p, map_location="cpu")),
+        ]
+
+        for filename, load_fn in weight_files:
+            weight_path = adapter_path / filename
+            if weight_path.exists():
+                return load_fn(weight_path)
+
+        raise FileNotFoundError(
+            f"No adapter weights found in {adapter_path}. "
+            f"Expected one of: {', '.join(filename for filename, _ in weight_files)}"
+        )
+
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by loading and stacking them."""
+
+        lora_a_weights = []
+        lora_b_weights = []
+
+        for adapter in self.lora_config.adapters:
+            if self.projection_name not in adapter.target_modules:
+                # Create zero weights for adapters that don't target this projection
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+                continue
+
+            adapter_weights = self._load_adapter_weights(adapter.local_adapter_path)
+
+            # Determine module type from projection name
+            attn_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
+            mlp_projs = {"gate_proj", "up_proj", "down_proj"}
+            if self.projection_name in attn_projs:
+                module_type = "self_attn"
+            elif self.projection_name in mlp_projs:
+                module_type = "mlp"
+            else:
+                module_type = "self_attn"
+
+            layer_key = f"base_model.model.model.layers.{self.layer_idx}.{module_type}.{self.projection_name}"
+            lora_a_key = f"{layer_key}.lora_A.weight"
+            lora_b_key = f"{layer_key}.lora_B.weight"
+
+            if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
+                # Calculate scaling factor and fold it into lora_b
+                scaling = adapter.lora_alpha / adapter.r
+                if adapter.use_rslora:
+                    scaling = scaling / math.sqrt(adapter.r)
+                scaling = scaling * adapter.scaling_factor
+
+                lora_a_weights.append(adapter_weights[lora_a_key])
+                # scaling is pre-applied to lora_b_weights
+                lora_b_weights.append(adapter_weights[lora_b_key] * scaling)
+            else:
+                logger.warning(f"No LoRA weights found for {lora_a_key} or {lora_b_key}")
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+
+        # Stack weights along adapter dimension
+        max_rank = self.lora_config.max_lora_rank
+
+        # Pad smaller ranks to max_rank
+        padded_lora_a = []
+        padded_lora_b = []
+
+        for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+            current_rank = lora_a.shape[0]
+            if current_rank < max_rank:
+                # Pad with zeros
+                padded_a = torch.zeros(max_rank, self.in_features)
+                padded_b = torch.zeros(self.out_features, max_rank)
+                padded_a[:current_rank] = lora_a
+                padded_b[:, :current_rank] = lora_b
+                padded_lora_a.append(padded_a)
+                padded_lora_b.append(padded_b)
+            else:
+                padded_lora_a.append(lora_a)
+                padded_lora_b.append(lora_b)
+
+        lora_a_transposed = [lora_a.transpose(0, 1) for lora_a in padded_lora_a]  # [in_features, rank]
+        lora_b_transposed = [lora_b.transpose(0, 1) for lora_b in padded_lora_b]  # [rank, out_features]
+
+        self.register_buffer(
+            "lora_a_weights", torch.stack(lora_a_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, in_features, rank]
+        self.register_buffer(
+            "lora_b_weights", torch.stack(lora_b_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, rank, out_features]
+
+    def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Forward pass that combines base linear transformation with LoRA.
+
+        Args:
+            x: Input tensor [batch_size, seq_len, in_features]
+            lora_int_id: Adapter ID tensor [batch_size] indicating which adapter to use
+
+        Returns:
+            Output tensor [batch_size, seq_len, out_features]
+        """
+        # Base linear transformation
+        output = torch.nn.functional.linear(x, self.weight, self.bias)
+
+        # Apply LoRA if enabled and adapter ID is provided
+        if self._should_apply_lora() and lora_int_id is not None:
+            # Gather LoRA weights for each batch item
+            # lora_int_id: [batch_size] -> use as indices to select weights
+            selected_lora_a = self.lora_a_weights[lora_int_id]  # [batch_size, in_features, rank]
+            selected_lora_b = self.lora_b_weights[lora_int_id]  # [batch_size, rank, out_features]
+
+            # Batched matrix multiplication for LoRA computation
+            # x: [batch_size, seq_len, in_features]
+            # selected_lora_a: [batch_size, in_features, rank] (already transposed)
+            # selected_lora_b: [batch_size, rank, out_features] (already transposed)
+
+            # First matmul: x @ lora_a -> [batch_size, seq_len, rank]
+            temp = torch.bmm(x, selected_lora_a)
+
+            # Second matmul: temp @ lora_b -> [batch_size, seq_len, out_features]
+            lora_delta = torch.bmm(temp, selected_lora_b)
+
+            # Add LoRA delta to base output
+            output = output + lora_delta
+
+        return output
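For reference, below is a minimal, self-contained sketch (not taken from the package) that mirrors the forward-pass math of `LoRALinear` above: a base linear projection plus a per-request LoRA delta, where `lora_int_id` indexes into pre-stacked, pre-scaled adapter tensors. All shapes and values are illustrative assumptions.

```python
# Sketch of static multi-adapter LoRA selection via batched matmul.
# Assumed shapes only; the real class loads, scales, pads, and stacks
# adapter weights into buffers of exactly these layouts at compile time.
import torch

batch, seq, d_in, d_out, rank, num_adapters = 2, 4, 8, 6, 4, 3

weight = torch.randn(d_out, d_in)                # base weight, as in self.weight
lora_a = torch.randn(num_adapters, d_in, rank)   # [num_adapters, in_features, rank]
lora_b = torch.randn(num_adapters, rank, d_out)  # [num_adapters, rank, out_features]

x = torch.randn(batch, seq, d_in)
lora_int_id = torch.tensor([0, 2])               # adapter index chosen per batch item

# Base projection plus per-request LoRA delta, computed with two bmm calls.
base = torch.nn.functional.linear(x, weight)                    # [batch, seq, out_features]
delta = torch.bmm(torch.bmm(x, lora_a[lora_int_id]), lora_b[lora_int_id])
output = base + delta

# Reference: apply each request's adapter individually and compare.
ref = torch.stack(
    [base[i] + x[i] @ lora_a[aid] @ lora_b[aid] for i, aid in enumerate(lora_int_id.tolist())]
)
assert torch.allclose(output, ref, atol=1e-5)
```

Because every adapter is padded to `max_lora_rank` and stacked into fixed-shape buffers when the model is built, selecting an adapter at runtime reduces to an index lookup plus two `torch.bmm` calls, which fits a statically compiled graph.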