optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/lora_architecture.py
@@ -0,0 +1,204 @@
+import math
+from pathlib import Path
+from typing import Optional
+
+import safetensors.torch
+import torch
+from torch import nn
+
+from ....utils import logging
+from .configuration_lora import RBLNLoRAConfig
+
+
+logger = logging.get_logger()
+
+
+class LoRALinear(nn.Module):
+    """
+    A linear layer that supports multiple LoRA adapters compiled at static time.
+
+    This class replaces the original linear layer and handles both base weights
+    and multiple LoRA adapters in a single forward pass using custom ops.
+    """
+
+    def __init__(
+        self,
+        original_linear: nn.Linear,
+        lora_config: RBLNLoRAConfig,
+        projection_name: str = "",
+        layer_idx: int = 0,
+    ):
+        """
+        Args:
+            original_linear: The original linear layer to be replaced
+            lora_config: LoRA configuration containing all adapters
+            projection_name: Name of the projection (e.g., "q_proj", "k_proj")
+            layer_idx: Layer index for loading the correct LoRA weights
+        """
+        super().__init__()
+
+        self.in_features = original_linear.in_features
+        self.out_features = original_linear.out_features
+        self.projection_name = projection_name
+        self.layer_idx = layer_idx
+        self.lora_config = lora_config
+
+        # Store original linear weights and bias directly without cloning
+        self.register_buffer("weight", original_linear.weight.data)
+        if original_linear.bias is not None:
+            self.register_buffer("bias", original_linear.bias.data)
+        else:
+            self.bias = None
+
+        # Initialize LoRA weights
+        self._init_lora_weights()
+
+    def _should_apply_lora(self) -> bool:
+        """Check if this projection should have LoRA applied."""
+        # Check if any adapter targets this projection
+        return any(self.projection_name in adapter.target_modules for adapter in self.lora_config.adapters)
+
+    def _load_adapter_weights(self, adapter_path: Path):
+        """
+        Load adapter weights from local directory.
+
+        Args:
+            adapter_path: Path to local directory containing adapter weights
+
+        Returns:
+            Dictionary containing adapter weights
+
+        Raises:
+            FileNotFoundError: If no adapter weights are found in the directory
+        """
+        if not adapter_path.is_dir():
+            raise ValueError(f"Adapter path must be a directory, got: {adapter_path}")
+
+        # Try to load weights in order of preference
+        weight_files = [
+            ("adapter_model.safetensors", lambda p: safetensors.torch.load_file(p)),
+            ("adapter_model.bin", lambda p: torch.load(p, map_location="cpu")),
+            ("pytorch_model.bin", lambda p: torch.load(p, map_location="cpu")),
+        ]
+
+        for filename, load_fn in weight_files:
+            weight_path = adapter_path / filename
+            if weight_path.exists():
+                return load_fn(weight_path)
+
+        raise FileNotFoundError(
+            f"No adapter weights found in {adapter_path}. "
+            f"Expected one of: {', '.join(filename for filename, _ in weight_files)}"
+        )
+
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by loading and stacking them."""
+
+        lora_a_weights = []
+        lora_b_weights = []
+
+        for adapter in self.lora_config.adapters:
+            if self.projection_name not in adapter.target_modules:
+                # Create zero weights for adapters that don't target this projection
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+                continue
+
+            adapter_weights = self._load_adapter_weights(adapter.local_adapter_path)
+
+            # Determine module type from projection name
+            attn_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
+            mlp_projs = {"gate_proj", "up_proj", "down_proj"}
+            if self.projection_name in attn_projs:
+                module_type = "self_attn"
+            elif self.projection_name in mlp_projs:
+                module_type = "mlp"
+            else:
+                module_type = "self_attn"
+
+            layer_key = f"base_model.model.model.layers.{self.layer_idx}.{module_type}.{self.projection_name}"
+            lora_a_key = f"{layer_key}.lora_A.weight"
+            lora_b_key = f"{layer_key}.lora_B.weight"
+
+            if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
+                # Calculate scaling factor and fold it into lora_b
+                scaling = adapter.lora_alpha / adapter.r
+                if adapter.use_rslora:
+                    scaling = scaling / math.sqrt(adapter.r)
+                scaling = scaling * adapter.scaling_factor
+
+                lora_a_weights.append(adapter_weights[lora_a_key])
+                # scaling is pre-applied to lora_b_weights
+                lora_b_weights.append(adapter_weights[lora_b_key] * scaling)
+            else:
+                logger.warning(f"No LoRA weights found for {lora_a_key} or {lora_b_key}")
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+
+        # Stack weights along adapter dimension
+        max_rank = self.lora_config.max_lora_rank
+
+        # Pad smaller ranks to max_rank
+        padded_lora_a = []
+        padded_lora_b = []
+
+        for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+            current_rank = lora_a.shape[0]
+            if current_rank < max_rank:
+                # Pad with zeros
+                padded_a = torch.zeros(max_rank, self.in_features)
+                padded_b = torch.zeros(self.out_features, max_rank)
+                padded_a[:current_rank] = lora_a
+                padded_b[:, :current_rank] = lora_b
+                padded_lora_a.append(padded_a)
+                padded_lora_b.append(padded_b)
+            else:
+                padded_lora_a.append(lora_a)
+                padded_lora_b.append(lora_b)
+
+        lora_a_transposed = [lora_a.transpose(0, 1) for lora_a in padded_lora_a]  # [in_features, rank]
+        lora_b_transposed = [lora_b.transpose(0, 1) for lora_b in padded_lora_b]  # [rank, out_features]
+
+        self.register_buffer(
+            "lora_a_weights", torch.stack(lora_a_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, in_features, rank]
+        self.register_buffer(
+            "lora_b_weights", torch.stack(lora_b_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, rank, out_features]
+
+    def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Forward pass that combines base linear transformation with LoRA.
+
+        Args:
+            x: Input tensor [batch_size, seq_len, in_features]
+            lora_int_id: Adapter ID tensor [batch_size] indicating which adapter to use
+
+        Returns:
+            Output tensor [batch_size, seq_len, out_features]
+        """
+        # Base linear transformation
+        output = torch.nn.functional.linear(x, self.weight, self.bias)
+
+        # Apply LoRA if enabled and adapter ID is provided
+        if self._should_apply_lora() and lora_int_id is not None:
+            # Gather LoRA weights for each batch item
+            # lora_int_id: [batch_size] -> use as indices to select weights
+            selected_lora_a = self.lora_a_weights[lora_int_id]  # [batch_size, in_features, rank]
+            selected_lora_b = self.lora_b_weights[lora_int_id]  # [batch_size, rank, out_features]
+
+            # Batched matrix multiplication for LoRA computation
+            # x: [batch_size, seq_len, in_features]
+            # selected_lora_a: [batch_size, in_features, rank] (already transposed)
+            # selected_lora_b: [batch_size, rank, out_features] (already transposed)
+
+            # First matmul: x @ lora_a -> [batch_size, seq_len, rank]
+            temp = torch.bmm(x, selected_lora_a)
+
+            # Second matmul: temp @ lora_b -> [batch_size, seq_len, out_features]
+            lora_delta = torch.bmm(temp, selected_lora_b)
+
+            # Add LoRA delta to base output
+            output = output + lora_delta
+
+        return output
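
For context on how the new LoRALinear layer is driven, here is a minimal usage sketch, not taken from the package. It assumes the module is importable at the path shown in the file list, and the FakeAdapter/FakeLoRAConfig dataclasses are hypothetical stand-ins for RBLNLoRAConfig (defined in configuration_lora.py, which this hunk does not show), exposing only the fields the code above actually reads: adapters, max_lora_rank, r, lora_alpha, use_rslora, scaling_factor, target_modules, and local_adapter_path.

# Hypothetical usage sketch; FakeAdapter/FakeLoRAConfig stand in for the real
# RBLNLoRAConfig, whose definition is not shown in this hunk.
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import List

import safetensors.torch
import torch
from torch import nn

from optimum.rbln.transformers.models.decoderonly.lora_architecture import LoRALinear


@dataclass
class FakeAdapter:  # stand-in for one entry of RBLNLoRAConfig.adapters
    r: int
    lora_alpha: float
    target_modules: List[str]
    local_adapter_path: Path
    use_rslora: bool = False
    scaling_factor: float = 1.0


@dataclass
class FakeLoRAConfig:  # stand-in for RBLNLoRAConfig
    adapters: List[FakeAdapter]
    max_lora_rank: int


# Write a tiny adapter checkpoint with the key layout _init_lora_weights expects.
adapter_dir = Path(tempfile.mkdtemp())
prefix = "base_model.model.model.layers.0.self_attn.q_proj"
safetensors.torch.save_file(
    {
        f"{prefix}.lora_A.weight": torch.randn(8, 64),  # [r, in_features]
        f"{prefix}.lora_B.weight": torch.randn(64, 8),  # [out_features, r]
    },
    str(adapter_dir / "adapter_model.safetensors"),
)

config = FakeLoRAConfig(
    adapters=[FakeAdapter(r=8, lora_alpha=16.0, target_modules=["q_proj"], local_adapter_path=adapter_dir)],
    max_lora_rank=8,
)
layer = LoRALinear(nn.Linear(64, 64), config, projection_name="q_proj", layer_idx=0)

x = torch.randn(2, 10, 64)          # [batch_size, seq_len, in_features]
lora_int_id = torch.tensor([0, 0])  # each batch item selects adapter 0
out = layer(x, lora_int_id)         # [2, 10, 64]

The design choice worth noting in _init_lora_weights: the per-adapter scaling (lora_alpha / r, further divided by sqrt(r) when use_rslora is set, then multiplied by scaling_factor) is folded into lora_B at load time, and all adapters are zero-padded to max_lora_rank and stacked into two fixed-shape buffers. The forward pass then reduces to a gather on lora_int_id plus two torch.bmm calls, a shape-static pattern suited to ahead-of-time compilation on RBLN devices.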