optimum-rbln 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of optimum-rbln might be problematic.

Files changed (64)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +63 -32
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  11. optimum/rbln/diffusers/modeling_diffusers.py +5 -7
  12. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  13. optimum/rbln/modeling.py +50 -0
  14. optimum/rbln/modeling_base.py +1 -2
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_generic.py +37 -1
  17. optimum/rbln/transformers/models/__init__.py +9 -0
  18. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  19. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +86 -23
  20. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  21. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  22. optimum/rbln/transformers/models/colpali/configuration_colpali.py +34 -18
  23. optimum/rbln/transformers/models/colpali/modeling_colpali.py +73 -80
  24. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  25. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  26. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  27. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  28. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  29. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  30. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  32. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +50 -2
  33. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  34. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +65 -3
  35. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  36. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  37. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  38. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +67 -44
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +24 -19
  42. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  43. optimum/rbln/transformers/models/llava/modeling_llava.py +108 -50
  44. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  45. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -343
  46. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  47. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +9 -8
  50. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +24 -0
  51. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +24 -0
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  55. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  57. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  58. optimum/rbln/utils/runtime_utils.py +25 -15
  59. optimum/rbln/utils/submodule.py +21 -5
  60. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA +5 -5
  61. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/RECORD +64 -55
  62. optimum_rbln-0.9.2.dist-info/entry_points.txt +2 -0
  63. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

@@ -187,6 +187,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
         )
 
+        self.lora_int_ids = None
+
     def inputs_embeddings_if_needed(
         self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
     ):
@@ -210,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ):
         inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
         block_tables, local_block_tables, is_external_block_tables = (
@@ -233,6 +236,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 position_embed=position_embed,
                 position_ids=position_ids,
                 local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
             )
         else:
             return self.prefill_forward(
@@ -245,6 +249,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 position_embed=position_embed,
                 token_type_ids=token_type_ids,
                 local_block_tables=local_block_tables,
+                lora_int_ids=lora_int_ids,
             )
 
     def decode_forward(
@@ -257,7 +262,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            lora_int_ids = self.lora_int_ids
+
+        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
         if self.batch_size != cache_position.shape[0]:
             raise RuntimeError(
                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
@@ -287,6 +305,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed,
             attention_mask if self.rbln_config.use_attention_mask else None,
             position_ids if self.rbln_config.use_position_ids else None,
+            lora_int_ids if self.rbln_config.use_lora else None,
         )
 
         return RBLNDecoderOnlyOutput(logits=logits)
@@ -369,12 +388,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            if batch_idx is not None:
+                lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+            else:
+                lora_int_ids = self.lora_int_ids.clone()
+
         (
             inputs,
             cache_position,
@@ -388,6 +420,16 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
 
+        # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
+        prefix_cached_len = cache_position[0][0].item()
+        if prefix_cached_len > 0:
+            if prefix_cached_len % self.rbln_config.prefill_chunk_size != 0:
+                raise NotImplementedError(
+                    "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
+                )
+            if self.rbln_config.use_attention_mask:
+                chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
+
         # Process input in chunks of size `prefill_chunk_size`
         output_logits = []
         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
@@ -402,9 +444,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
                 if step > 0:  # update previous chunk
                     chunked_attention_mask[
-                        :, :, :, s - self.rbln_config.prefill_chunk_size : e - self.rbln_config.prefill_chunk_size
+                        :,
+                        :,
+                        :,
+                        s - self.rbln_config.prefill_chunk_size + prefix_cached_len : e
+                        - self.rbln_config.prefill_chunk_size
+                        + prefix_cached_len,
                     ] = 1
-                chunked_attention_mask[:, :, :, s:e] = self.causal_mask
+                chunked_attention_mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = self.causal_mask
 
             # Calculate query position if needed
             if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
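Note on the chunked-prefill change above: the new prefix_cached_len offset simply shifts every chunk's slice of the attention mask to the right of the externally cached prefix. The following standalone sketch is not code from the package; the tensor names and the tril-based causal pattern are illustrative, but it reproduces the same masking arithmetic.

import torch

prefill_chunk_size = 4
prefix_cached_len = 4   # must be a multiple of prefill_chunk_size in this release
query_length = 8        # new tokens to prefill after the cached prefix
total_len = prefix_cached_len + query_length

# [batch, 1, chunk, total] mask; the cached prefix is visible to every chunk
mask = torch.zeros(1, 1, prefill_chunk_size, total_len)
mask[:, :, :, :prefix_cached_len] = 1
causal = torch.tril(torch.ones(prefill_chunk_size, prefill_chunk_size))

for step in range(0, query_length, prefill_chunk_size):
    s, e = step, step + prefill_chunk_size
    if step > 0:
        # the previous chunk becomes fully visible, shifted by the cached prefix
        mask[:, :, :, s - prefill_chunk_size + prefix_cached_len : e - prefill_chunk_size + prefix_cached_len] = 1
    # the current chunk gets the causal pattern, also shifted by the cached prefix
    mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = causal
    # ... run one prefill chunk with `mask` here ...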
@@ -426,6 +473,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 query_position,
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
+                lora_int_ids if self.rbln_config.use_lora else None,
                 out=self.out_buffers,
             )
             output_logits.append(output_logit)
optimum/rbln/transformers/models/decoderonly/lora_architecture.py

@@ -0,0 +1,204 @@
+import math
+from pathlib import Path
+from typing import Optional
+
+import safetensors.torch
+import torch
+from torch import nn
+
+from ....utils import logging
+from .configuration_lora import RBLNLoRAConfig
+
+
+logger = logging.get_logger()
+
+
+class LoRALinear(nn.Module):
+    """
+    A linear layer that supports multiple LoRA adapters compiled at static time.
+
+    This class replaces the original linear layer and handles both base weights
+    and multiple LoRA adapters in a single forward pass using custom ops.
+    """
+
+    def __init__(
+        self,
+        original_linear: nn.Linear,
+        lora_config: RBLNLoRAConfig,
+        projection_name: str = "",
+        layer_idx: int = 0,
+    ):
+        """
+        Args:
+            original_linear: The original linear layer to be replaced
+            lora_config: LoRA configuration containing all adapters
+            projection_name: Name of the projection (e.g., "q_proj", "k_proj")
+            layer_idx: Layer index for loading the correct LoRA weights
+        """
+        super().__init__()
+
+        self.in_features = original_linear.in_features
+        self.out_features = original_linear.out_features
+        self.projection_name = projection_name
+        self.layer_idx = layer_idx
+        self.lora_config = lora_config
+
+        # Store original linear weights and bias directly without cloning
+        self.register_buffer("weight", original_linear.weight.data)
+        if original_linear.bias is not None:
+            self.register_buffer("bias", original_linear.bias.data)
+        else:
+            self.bias = None
+
+        # Initialize LoRA weights
+        self._init_lora_weights()
+
+    def _should_apply_lora(self) -> bool:
+        """Check if this projection should have LoRA applied."""
+        # Check if any adapter targets this projection
+        return any(self.projection_name in adapter.target_modules for adapter in self.lora_config.adapters)
+
+    def _load_adapter_weights(self, adapter_path: Path):
+        """
+        Load adapter weights from local directory.
+
+        Args:
+            adapter_path: Path to local directory containing adapter weights
+
+        Returns:
+            Dictionary containing adapter weights
+
+        Raises:
+            FileNotFoundError: If no adapter weights are found in the directory
+        """
+        if not adapter_path.is_dir():
+            raise ValueError(f"Adapter path must be a directory, got: {adapter_path}")
+
+        # Try to load weights in order of preference
+        weight_files = [
+            ("adapter_model.safetensors", lambda p: safetensors.torch.load_file(p)),
+            ("adapter_model.bin", lambda p: torch.load(p, map_location="cpu")),
+            ("pytorch_model.bin", lambda p: torch.load(p, map_location="cpu")),
+        ]
+
+        for filename, load_fn in weight_files:
+            weight_path = adapter_path / filename
+            if weight_path.exists():
+                return load_fn(weight_path)
+
+        raise FileNotFoundError(
+            f"No adapter weights found in {adapter_path}. "
+            f"Expected one of: {', '.join(filename for filename, _ in weight_files)}"
+        )
+
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by loading and stacking them."""
+
+        lora_a_weights = []
+        lora_b_weights = []
+
+        for adapter in self.lora_config.adapters:
+            if self.projection_name not in adapter.target_modules:
+                # Create zero weights for adapters that don't target this projection
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+                continue
+
+            adapter_weights = self._load_adapter_weights(adapter.local_adapter_path)
+
+            # Determine module type from projection name
+            attn_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
+            mlp_projs = {"gate_proj", "up_proj", "down_proj"}
+            if self.projection_name in attn_projs:
+                module_type = "self_attn"
+            elif self.projection_name in mlp_projs:
+                module_type = "mlp"
+            else:
+                module_type = "self_attn"
+
+            layer_key = f"base_model.model.model.layers.{self.layer_idx}.{module_type}.{self.projection_name}"
+            lora_a_key = f"{layer_key}.lora_A.weight"
+            lora_b_key = f"{layer_key}.lora_B.weight"
+
+            if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
+                # Calculate scaling factor and fold it into lora_b
+                scaling = adapter.lora_alpha / adapter.r
+                if adapter.use_rslora:
+                    scaling = scaling / math.sqrt(adapter.r)
+                scaling = scaling * adapter.scaling_factor
+
+                lora_a_weights.append(adapter_weights[lora_a_key])
+                # scaling is pre-applied to lora_b_weights
+                lora_b_weights.append(adapter_weights[lora_b_key] * scaling)
+            else:
+                logger.warning(f"No LoRA weights found for {lora_a_key} or {lora_b_key}")
+                lora_a_weights.append(torch.zeros(adapter.r, self.in_features))
+                lora_b_weights.append(torch.zeros(self.out_features, adapter.r))
+
+        # Stack weights along adapter dimension
+        max_rank = self.lora_config.max_lora_rank
+
+        # Pad smaller ranks to max_rank
+        padded_lora_a = []
+        padded_lora_b = []
+
+        for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+            current_rank = lora_a.shape[0]
+            if current_rank < max_rank:
+                # Pad with zeros
+                padded_a = torch.zeros(max_rank, self.in_features)
+                padded_b = torch.zeros(self.out_features, max_rank)
+                padded_a[:current_rank] = lora_a
+                padded_b[:, :current_rank] = lora_b
+                padded_lora_a.append(padded_a)
+                padded_lora_b.append(padded_b)
+            else:
+                padded_lora_a.append(lora_a)
+                padded_lora_b.append(lora_b)
+
+        lora_a_transposed = [lora_a.transpose(0, 1) for lora_a in padded_lora_a]  # [in_features, rank]
+        lora_b_transposed = [lora_b.transpose(0, 1) for lora_b in padded_lora_b]  # [rank, out_features]
+
+        self.register_buffer(
+            "lora_a_weights", torch.stack(lora_a_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, in_features, rank]
+        self.register_buffer(
+            "lora_b_weights", torch.stack(lora_b_transposed, dim=0).to(self.weight.dtype)
+        )  # [num_adapters, rank, out_features]
+
+    def forward(self, x: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Forward pass that combines base linear transformation with LoRA.
+
+        Args:
+            x: Input tensor [batch_size, seq_len, in_features]
+            lora_int_id: Adapter ID tensor [batch_size] indicating which adapter to use
+
+        Returns:
+            Output tensor [batch_size, seq_len, out_features]
+        """
+        # Base linear transformation
+        output = torch.nn.functional.linear(x, self.weight, self.bias)
+
+        # Apply LoRA if enabled and adapter ID is provided
+        if self._should_apply_lora() and lora_int_id is not None:
+            # Gather LoRA weights for each batch item
+            # lora_int_id: [batch_size] -> use as indices to select weights
+            selected_lora_a = self.lora_a_weights[lora_int_id]  # [batch_size, in_features, rank]
+            selected_lora_b = self.lora_b_weights[lora_int_id]  # [batch_size, rank, out_features]
+
+            # Batched matrix multiplication for LoRA computation
+            # x: [batch_size, seq_len, in_features]
+            # selected_lora_a: [batch_size, in_features, rank] (already transposed)
+            # selected_lora_b: [batch_size, rank, out_features] (already transposed)
+
+            # First matmul: x @ lora_a -> [batch_size, seq_len, rank]
+            temp = torch.bmm(x, selected_lora_a)
+
+            # Second matmul: temp @ lora_b -> [batch_size, seq_len, out_features]
+            lora_delta = torch.bmm(temp, selected_lora_b)
+
+            # Add LoRA delta to base output
+            output = output + lora_delta
+
+        return output
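The new LoRALinear forward pass boils down to a per-sample gather over the stacked adapter matrices followed by two batched matmuls. The snippet below only illustrates that math with random tensors; it does not use the package's classes, all names and shapes are made up, and the scaling is assumed to be pre-folded into the B matrices as the file above does.

import torch

batch, seq, d_in, d_out, rank, num_adapters = 2, 5, 16, 32, 4, 3

weight = torch.randn(d_out, d_in)                # base linear weight
lora_a = torch.randn(num_adapters, d_in, rank)   # stacked A, already transposed
lora_b = torch.randn(num_adapters, rank, d_out)  # stacked B, scaling pre-applied

x = torch.randn(batch, seq, d_in)
lora_int_id = torch.tensor([0, 2])               # one adapter id per batch item

base = torch.nn.functional.linear(x, weight)     # [batch, seq, d_out]
a_sel = lora_a[lora_int_id]                      # [batch, d_in, rank]
b_sel = lora_b[lora_int_id]                      # [batch, rank, d_out]
delta = torch.bmm(torch.bmm(x, a_sel), b_sel)    # [batch, seq, d_out]
out = base + delta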
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -57,7 +57,6 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for decoder-only architectures
-
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures.
 
@@ -68,6 +67,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     - The class handles RBLN-specific optimizations automatically during compilation
     """
 
+    _tp_support = True
+
     main_input_name = "input_ids"
     auto_model_class = AutoModel
     _decoder_wrapper_cls = DecoderOnlyWrapper
@@ -259,10 +260,12 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
         # Mark static tensors (self kv states)
         static_tensors = {}
+        idx = 0
         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
             if "past_key_values" in name:
                 static_tensors[name] = tensor
-                context.mark_static_address(tensor)
+                context.mark_static_address(tensor, f"kv_cache_{idx}")
+                idx += 1
 
         return context, static_tensors
 
@@ -374,6 +377,9 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
+        if rbln_config.use_lora:
+            input_info.append(("lora_int_ids", [batch_size], "int32"))
+
         kvcache_dtype = rbln_config.torch_dtype
         if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
             kvcache_dtype = "float8_e4m3fn"
@@ -642,7 +648,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
-
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures and causal language modeling tasks.
 
@@ -667,6 +672,53 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
         return is_prefill
 
+    def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+        if isinstance(lora_int_ids, int):
+            lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+        elif isinstance(lora_int_ids, list):
+            lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
+
+        self.lora_int_ids = lora_int_ids
+
+        self.prefill_decoder.lora_int_ids = lora_int_ids
+        if self.rbln_config.can_generate:
+            for batch_size in self.rbln_config.decoder_batch_sizes:
+                self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets the active adapter(s) for the model using adapter name(s).
+
+        Args:
+            adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                Can be a single adapter name or a list of adapter names.
+
+        Raises:
+            ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+        """
+        if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+            raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+        # Convert single adapter name to list for uniform processing
+        if isinstance(adapter_name, str):
+            adapter_names = [adapter_name]
+        else:
+            adapter_names = adapter_name
+
+        # Validate that all adapter names exist
+        available_adapters = {
+            adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+        }
+        missing_adapters = [name for name in adapter_names if name not in available_adapters]
+        if missing_adapters:
+            raise ValueError(
+                f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+            )
+
+        # Get the adapter IDs and set them
+        lora_int_ids = [available_adapters[name] for name in adapter_names]
+        self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
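The set_adapter method added above is essentially a name-to-id lookup over rbln_config.lora_config.adapters followed by set_lora_int_ids. Below is a self-contained sketch of that resolution step using stand-in adapter records rather than the package's config classes; the adapter names "summarize"/"translate", the AdapterEntry dataclass, and resolve_adapter_ids are hypothetical.

from dataclasses import dataclass
from typing import List, Union

import torch


@dataclass
class AdapterEntry:  # stand-in for the entries in lora_config.adapters
    lora_name: str
    lora_int_id: int


adapters = [AdapterEntry("summarize", 0), AdapterEntry("translate", 1)]


def resolve_adapter_ids(adapter_name: Union[str, List[str]]) -> torch.Tensor:
    names = [adapter_name] if isinstance(adapter_name, str) else adapter_name
    available = {a.lora_name: a.lora_int_id for a in adapters}
    missing = [n for n in names if n not in available]
    if missing:
        raise ValueError(f"Adapter(s) {missing} not found. Available adapters: {list(available)}")
    return torch.tensor([available[n] for n in names], dtype=torch.int32)


print(resolve_adapter_ids("translate"))                 # tensor([1], dtype=torch.int32)
print(resolve_adapter_ids(["summarize", "translate"]))  # tensor([0, 1], dtype=torch.int32)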
@@ -677,6 +729,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -684,6 +737,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
         # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
         # A for-loop ensures synchronization with the HuggingFace generate API.
         # The decoder stage operates as usual, processing inputs in batch mode.
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            lora_int_ids = self.lora_int_ids
 
         # for only use forward
         if generate_idx is None:
@@ -708,6 +768,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -727,6 +788,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                lora_int_ids=lora_int_ids,
             ).logits
 
         if not return_dict:
optimum/rbln/transformers/models/gemma3/configuration_gemma3.py

@@ -14,8 +14,11 @@
 from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
 from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
-from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
+
+
+logger = get_logger(__name__)
 
 
 class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -89,8 +92,13 @@ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_tower = self.init_submodule_config(RBLNSiglipVisionModelConfig, vision_tower)
-        self.language_model = self.init_submodule_config(RBLNGemma3ForCausalLMConfig, language_model)
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Gemma3 vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        )
+        self.language_model = self.initialize_submodule_config(submodule_config=language_model)
 
     @property
     def image_prefill_chunk_size(self):
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py

@@ -63,6 +63,7 @@ class Gemma3TextModel(DecoderOnlyModel):
         rotary_emb: torch.nn.Module = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -105,6 +106,7 @@ class Gemma3TextModel(DecoderOnlyModel):
                 cos=cos_local if is_sliding else cos_global,
                 sin=sin_local if is_sliding else sin_global,
                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
@@ -127,12 +129,20 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
         hidden_states = self.self_attn(
-            hidden_states, attention_mask, seq_positions, past_key_values, cos, sin, block_tables
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            seq_positions=seq_positions,
+            past_key_values=past_key_values,
+            cos=cos,
+            sin=sin,
+            block_tables=block_tables,
+            lora_int_id=lora_int_id,
         )
         hidden_states = self.get_post_attention_layernorm()(hidden_states)
         hidden_states = residual + hidden_states
@@ -140,7 +150,7 @@ class Gemma3DecoderLayer(DecoderOnlyLayer):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
-        hidden_states = self._original_mod.mlp(hidden_states)
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
         hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
         hidden_states = residual + hidden_states
 
optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py

@@ -17,15 +17,16 @@ import rebel
 import torch
 
 from ...modeling_outputs import RBLNDecoderOnlyOutput, RBLNGemma3ForCausalLMOutput
+from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
 from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
 
 
 class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
     def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
         super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
+        self.image_prefill = RBLNPytorchRuntime(image_prefill)  # FIXME(taehoon)
+        self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None  # FIXME
+        self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None
 
     def _prepare_prefill_inputs(self, *args, **kwargs):
         (
@@ -73,12 +74,24 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            if batch_idx is not None:
+                lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
+            else:
+                lora_int_ids = self.lora_int_ids.clone()
+
         (
             inputs,
             cache_position,
@@ -141,6 +154,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     query_position,
                     chunked_attention_mask,
                     position_ids_chunk,
+                    lora_int_ids if self.rbln_config.use_lora else None,
                 )
             else:
                 logits = self.prefill(
@@ -151,6 +165,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     query_position,
                     chunked_attention_mask,
                     position_ids_chunk,
+                    lora_int_ids if self.rbln_config.use_lora else None,
                 )
 
             padded_cache_lengths += current_padded_cache_lengths
@@ -173,7 +188,20 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+
+            lora_int_ids = self.lora_int_ids
+
+        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
+            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
+
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
             raise RuntimeError(