optimum-rbln 0.9.2a3__py3-none-any.whl → 0.9.2a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (25)
  1. optimum/rbln/__init__.py +4 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +3 -0
  4. optimum/rbln/transformers/__init__.py +4 -0
  5. optimum/rbln/transformers/models/__init__.py +4 -0
  6. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  7. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  8. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  9. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  10. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
  11. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +60 -0
  13. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  14. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  15. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +7 -0
  16. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  17. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  18. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  19. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
  20. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
  21. optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
  22. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a4.dist-info}/METADATA +1 -1
  23. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a4.dist-info}/RECORD +25 -23
  24. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a4.dist-info}/WHEEL +0 -0
  25. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -118,6 +118,8 @@ _import_structure = {
          "RBLNLlavaForConditionalGenerationConfig",
          "RBLNLlavaNextForConditionalGeneration",
          "RBLNLlavaNextForConditionalGenerationConfig",
+         "RBLNLoRAAdapterConfig",
+         "RBLNLoRAConfig",
          "RBLNMidmLMHeadModel",
          "RBLNMidmLMHeadModelConfig",
          "RBLNMistralModel",
@@ -406,6 +408,8 @@ if TYPE_CHECKING:
          RBLNLlavaForConditionalGenerationConfig,
          RBLNLlavaNextForConditionalGeneration,
          RBLNLlavaNextForConditionalGenerationConfig,
+         RBLNLoRAAdapterConfig,
+         RBLNLoRAConfig,
          RBLNMidmLMHeadModel,
          RBLNMidmLMHeadModelConfig,
          RBLNMistralForCausalLM,
optimum/rbln/__version__.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.9.2a3'
- __version_tuple__ = version_tuple = (0, 9, 2, 'a3')
+ __version__ = version = '0.9.2a4'
+ __version_tuple__ = version_tuple = (0, 9, 2, 'a4')


  __commit_id__ = commit_id = None
optimum/rbln/configuration_utils.py CHANGED
@@ -41,6 +41,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
  class RBLNSerializableConfigProtocol(Protocol):
      def _prepare_for_serialization(self) -> Dict[str, Any]: ...

+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
+

  @dataclass
  class RBLNCompileConfig:
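
The new default `__repr__` on `RBLNSerializableConfigProtocol` means any config that implements `_prepare_for_serialization` now prints as its class name plus its serialization dict. A minimal sketch of the resulting behavior (the `ToyConfig` class below is hypothetical, not part of optimum-rbln):

```python
from typing import Any, Dict


class ToyConfig:
    """Hypothetical stand-in for a config implementing the protocol."""

    def __init__(self, max_seq_len: int = 8192) -> None:
        self.max_seq_len = max_seq_len

    def _prepare_for_serialization(self) -> Dict[str, Any]:
        return {"max_seq_len": self.max_seq_len}

    def __repr__(self) -> str:
        # Mirrors the default __repr__ added to RBLNSerializableConfigProtocol above
        return f"{self.__class__.__name__}({self._prepare_for_serialization()})"


print(ToyConfig())  # -> ToyConfig({'max_seq_len': 8192})
```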
optimum/rbln/transformers/__init__.py CHANGED
@@ -110,6 +110,8 @@ _import_structure = {
          "RBLNPegasusModelConfig",
          "RBLNLlavaNextForConditionalGeneration",
          "RBLNLlavaNextForConditionalGenerationConfig",
+         "RBLNLoRAAdapterConfig",
+         "RBLNLoRAConfig",
          "RBLNMidmLMHeadModel",
          "RBLNMidmLMHeadModelConfig",
          "RBLNMistralForCausalLM",
@@ -258,6 +260,8 @@ if TYPE_CHECKING:
          RBLNLlavaForConditionalGenerationConfig,
          RBLNLlavaNextForConditionalGeneration,
          RBLNLlavaNextForConditionalGenerationConfig,
+         RBLNLoRAAdapterConfig,
+         RBLNLoRAConfig,
          RBLNMidmLMHeadModel,
          RBLNMidmLMHeadModelConfig,
          RBLNMistralForCausalLM,
optimum/rbln/transformers/models/__init__.py CHANGED
@@ -96,6 +96,8 @@ _import_structure = {
          "RBLNDecoderOnlyModel",
          "RBLNDecoderOnlyModelForCausalLM",
          "RBLNDecoderOnlyModelForCausalLMConfig",
+         "RBLNLoRAAdapterConfig",
+         "RBLNLoRAConfig",
      ],
      "depth_anything": ["RBLNDepthAnythingForDepthEstimationConfig", "RBLNDepthAnythingForDepthEstimation"],
      "dpt": [
@@ -239,6 +241,8 @@ if TYPE_CHECKING:
          RBLNDecoderOnlyModelConfig,
          RBLNDecoderOnlyModelForCausalLM,
          RBLNDecoderOnlyModelForCausalLMConfig,
+         RBLNLoRAAdapterConfig,
+         RBLNLoRAConfig,
      )
      from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
      from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
optimum/rbln/transformers/models/decoderonly/__init__.py CHANGED
@@ -23,4 +23,5 @@ from ....ops import (
      paged_flash_causal_attn_prefill,
  )
  from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .configuration_lora import RBLNLoRAAdapterConfig, RBLNLoRAConfig
  from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py CHANGED
@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, get_args
  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
  from ...utils.rbln_quantization import RBLNQuantizationConfig
+ from .configuration_lora import RBLNLoRAConfig


  logger = get_logger()
@@ -48,6 +49,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
          kvcache_partition_len: Optional[int] = None,
          kvcache_block_size: Optional[int] = None,
          quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
+         lora_config: Optional[Union[Dict[str, Any], RBLNLoRAConfig]] = None,
          prefill_chunk_size: Optional[int] = None,
          kvcache_num_blocks: Optional[int] = None,
          decoder_batch_sizes: Optional[List[int]] = None,
@@ -80,6 +82,12 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
          kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
              in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
              section below for details.
+         quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
+             quantization. Specifies format, etc.
+         lora_config (Optional[Union[Dict[str, Any], RBLNLoRAConfig]]): Configuration for LoRA
+             (Low-Rank Adaptation) settings when using (multi-)LoRA support. Can be provided as
+             a dictionary or an RBLNLoRAConfig instance. When provided, enables LoRA functionality
+             for the model compilation. Defaults to None (no LoRA).
          prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
              processing input sequences. Defaults to 128. Must be a positive integer
              divisible by 64. Affects prefill performance and memory usage.
@@ -185,6 +193,26 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
          if self.quantization and isinstance(self.quantization, dict):
              self.quantization = RBLNQuantizationConfig(**self.quantization)

+         self.lora_config = lora_config
+         if self.lora_config and isinstance(self.lora_config, dict):
+             self.lora_config = RBLNLoRAConfig(**self.lora_config)
+
+         # Validate LoRA adapters if LoRA is enabled
+         if self.lora_config is not None:
+             validation_results = self.lora_config.validate_adapter_weights()
+             failed_adapters = [adapter_id for adapter_id, is_valid in validation_results.items() if not is_valid]
+
+             if failed_adapters:
+                 raise ValueError(
+                     f"Some LoRA adapters failed validation and may not be accessible at compile time: {failed_adapters}. "
+                     "Please ensure all adapter weights are available and properly formatted."
+                 )
+
+             logger.info(
+                 f"LoRA configuration initialized with {self.lora_config.num_adapters} adapters: "
+                 f"{self.lora_config.adapter_ids}. Max rank: {self.lora_config.max_lora_rank}"
+             )
+
          self.attn_impl = attn_impl
          self.kvcache_partition_len = kvcache_partition_len
          self.kvcache_block_size = kvcache_block_size
@@ -204,6 +232,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
          if self.logits_to_keep is not None and self.logits_to_keep > 1:
              raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")

+         self.decoder_batch_sizes = None
          if "decode" in self.phases:
              self.decoder_batch_sizes = decoder_batch_sizes
              if self.decoder_batch_sizes is None:
@@ -243,6 +272,11 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
      def use_multiple_decoder(self) -> bool:
          return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1

+     @property
+     def use_lora(self):
+         """Check if LoRA is enabled for this configuration."""
+         return self.lora_config is not None
+
      @property
      def can_generate(self) -> bool:
          return "decode" in self.phases
optimum/rbln/transformers/models/decoderonly/configuration_lora.py ADDED
@@ -0,0 +1,411 @@
+ import json
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ from huggingface_hub import snapshot_download
+
+ from ....configuration_utils import RBLNSerializableConfigProtocol
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNLoRAAdapterConfig(RBLNSerializableConfigProtocol):
+     """
+     Configuration class for individual LoRA adapter settings.
+
+     This class represents a single LoRA adapter that will be compiled into the RBLN model.
+     Since RBLN NPU requires all adapters to be determined at compile time, each adapter
+     must be fully specified including its weights.
+
+     Examples:
+         ```python
+         from transformers import AutoTokenizer
+
+         from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig, RBLNLoRAAdapterConfig, RBLNLoRAConfig
+
+
+         model_id = "meta-llama/Llama-3.1-8B-Instruct"
+         lora_ids = [
+             "nvidia/llama-3.1-nemoguard-8b-topic-control",
+             "reissbaker/llama-3.1-8b-abliterated-lora",
+         ]
+         prompt = "What are the safety considerations for AI systems?"
+         tp_size = 4
+
+         # adapter id should be higher than 0
+         # 0 is reserved for base model
+         lora_config = RBLNLoRAConfig(
+             adapters=[
+                 RBLNLoRAAdapterConfig(1, "nemoguard", lora_ids[0]),
+                 RBLNLoRAAdapterConfig(2, "abliterated", lora_ids[1]),
+             ],
+         )
+
+         model = RBLNLlamaForCausalLM.from_pretrained(
+             model_id,
+             rbln_config=RBLNLlamaForCausalLMConfig(lora_config=lora_config, tensor_parallel_size=tp_size, max_seq_len=8192),
+             torch_dtype="auto",
+         )
+
+
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         tokenizer.pad_token = tokenizer.eos_token
+
+
+         prompt_template = tokenizer.apply_chat_template(
+             [
+                 {"role": "system", "content": "You are a helpful assistant. Always be concise"},
+                 {"role": "user", "content": prompt},
+             ],
+             add_generation_prompt=True,
+             tokenize=False,
+         )
+         inputs = tokenizer([prompt_template], return_tensors="pt")
+         input_len = inputs["input_ids"].shape[-1]
+
+         for adapter_name in lora_config.adapter_names:
+             model.set_adapter(adapter_name)
+             decoder_outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+             generated_text = tokenizer.decode(decoder_outputs[0][input_len:], skip_special_tokens=True)
+             print(generated_text + "\n")
+         ```
+
+     """
+
+     def __init__(
+         self,
+         lora_int_id: int,
+         lora_name: str,
+         lora_path: Union[str, Path],
+         r: Optional[int] = None,
+         lora_alpha: Optional[float] = None,
+         target_modules: Optional[List[str]] = None,
+         bias: Optional[str] = None,
+         use_rslora: Optional[bool] = None,
+         scaling_factor: Optional[float] = None,
+     ):
+         """
+         Args:
+             lora_int_id (int): Unique identifier for this LoRA adapter (e.g., 0, 1, 2).
+                 This ID will be used during runtime to select which adapter to use.
+             lora_name (str): Human-readable name for this adapter (e.g., "math_tuned", "code_tuned").
+             lora_path (Union[str, Path]): Path to the LoRA adapter weights directory or file.
+                 Must be accessible at compile time to load the weights.
+             r (Optional[int]): The rank of the LoRA approximation for this adapter. If None,
+                 will be loaded from adapter config file.
+             lora_alpha (Optional[float]): The LoRA scaling parameter for this adapter. If None,
+                 will be loaded from adapter config file.
+             target_modules (Optional[List[str]]): List of module names to apply LoRA to.
+                 If None, will be loaded from adapter config file or inherit from parent RBLNLoRAConfig.
+             bias (Optional[str]): Bias handling strategy. Options: "none", "all", "lora_only".
+                 If None, will be loaded from adapter config file.
+             use_rslora (Optional[bool]): Whether to use Rank-Stabilized LoRA. If None,
+                 will be loaded from adapter config file.
+             scaling_factor (Optional[float]): Additional scaling factor for this adapter. Defaults to 1.0.
+             **kwargs: Additional adapter-specific arguments.
+
+         Raises:
+             ValueError: If lora_int_id is None.
+             ValueError: If lora_path doesn't exist.
+             ValueError: If r is not a positive integer.
+             ValueError: If lora_alpha is not positive.
+         """
+         if lora_int_id is None:
+             raise ValueError("lora_int_id cannot be None")
+
+         if not isinstance(lora_int_id, int):
+             raise ValueError(f"lora_int_id must be an integer, got {type(lora_int_id)}")
+
+         self.lora_int_id = lora_int_id
+         self.lora_name = lora_name
+
+         # Keep original lora_path as provided by user (for serialization)
+         self.lora_path = Path(lora_path)
+
+         # Resolve to local directory path (for actual weight loading)
+         self.local_adapter_path = self._resolve_adapter_path(self.lora_path)
+
+         # Load adapter config and use as defaults
+         adapter_config = self._load_adapter_config()
+
+         # Set values from adapter config if not explicitly provided
+         self.r = r if r is not None else adapter_config.get("r", 8)
+         self.lora_alpha = lora_alpha if lora_alpha is not None else adapter_config.get("lora_alpha", 8.0)
+         self.target_modules = (
+             target_modules if target_modules is not None else adapter_config.get("target_modules", None)
+         )
+         self.bias = bias if bias is not None else adapter_config.get("bias", "none")
+         if self.bias not in ["none"]:
+             raise NotImplementedError("bias != 'none' is not supported yet")
+
+         self.use_rslora = use_rslora if use_rslora is not None else adapter_config.get("use_rslora", False)
+         self.scaling_factor = scaling_factor if scaling_factor is not None else 1.0
+
+         # Validate the final values
+         if not isinstance(self.r, int) or self.r <= 0:
+             raise ValueError(f"r must be a positive integer, got {self.r}")
+
+         if self.lora_alpha <= 0:
+             raise ValueError(f"lora_alpha must be positive, got {self.lora_alpha}")
+
+         if self.bias not in ["none", "all", "lora_only"]:
+             raise ValueError(f"bias must be one of ['none', 'all', 'lora_only'], got {self.bias}")
+
+     def _resolve_adapter_path(self, path: Path) -> Path:
+         """
+         Resolve the adapter path, downloading from HuggingFace Hub if necessary.
+
+         Args:
+             path: Local path or HuggingFace Hub model ID
+
+         Returns:
+             Path object pointing to local adapter directory
+
+         Raises:
+             ValueError: If the adapter cannot be found locally or downloaded
+         """
+         # If it's a local path and exists, return it
+         if path.exists():
+             return path
+
+         # If it's an absolute path that doesn't exist, raise error
+         if path.is_absolute():
+             raise ValueError(f"LoRA adapter path does not exist: {path.as_posix()}")
+
+         # Try to interpret as HuggingFace Hub model ID and download
+         try:
+             local_dir = snapshot_download(str(path), allow_patterns=["*.safetensors", "*.bin", "*.json"])
+             return Path(local_dir)
+         except Exception as e:
+             raise ValueError(
+                 f"Failed to download LoRA adapter '{path.as_posix()}' from HuggingFace Hub. "
+                 f"Please check if the model ID is correct or provide a valid local path. "
+                 f"Error: {e}"
+             )
+
+     def _load_adapter_config(self) -> Dict[str, Any]:
+         """
+         Load adapter configuration from adapter_config.json file.
+
+         Returns:
+             Dictionary containing adapter configuration
+
+         Raises:
+             ValueError: If adapter_config.json is not found or cannot be parsed
+         """
+         config_path = self.local_adapter_path / "adapter_config.json"
+
+         if not config_path.exists():
+             logger.warning(f"No adapter_config.json found at {config_path}, using default values")
+             return {}
+
+         try:
+             with open(config_path, "r", encoding="utf-8") as f:
+                 adapter_config = json.load(f)
+             logger.info(f"Loaded adapter config from {config_path}")
+             return adapter_config
+         except Exception as e:
+             logger.warning(f"Failed to load adapter config from {config_path}: {e}, using default values")
+             return {}
+
+     def _prepare_for_serialization(self) -> Dict[str, Any]:
+         config_dict = {
+             "lora_int_id": self.lora_int_id,
+             "lora_name": self.lora_name,
+             "lora_path": str(self.lora_path),
+             "r": self.r,
+             "lora_alpha": self.lora_alpha,
+             "target_modules": self.target_modules,
+             "bias": self.bias,
+             "use_rslora": self.use_rslora,
+             "scaling_factor": self.scaling_factor,
+         }
+         return config_dict
+
+
+ class RBLNLoRABaseAdapterConfig(RBLNLoRAAdapterConfig):
+     """
+     Special adapter config for the reserved base model adapter (lora_int_id = 0).
+     This adapter carries zero-effective LoRA weights by targeting no modules,
+     thereby producing no LoRA delta and yielding pure base-model behavior.
+     """
+
+     def __init__(
+         self,
+         lora_int_id: int = 0,
+         lora_name: str = "base",
+         lora_path: Union[str, Path] = "__reserved_base__",
+         r: Optional[int] = 1,
+         lora_alpha: Optional[float] = 1.0,
+         target_modules: Optional[List[str]] = None,
+         bias: Optional[str] = "none",
+         use_rslora: Optional[bool] = False,
+         scaling_factor: Optional[float] = 1.0,
+     ):
+         if lora_int_id != 0:
+             raise ValueError("RBLNLoRABaseAdapterConfig must have lora_int_id=0")
+
+         self.lora_int_id = 0
+         self.lora_name = lora_name
+         # Keep original lora_path for serialization purposes but do not resolve it.
+         self.lora_path = Path(str(lora_path))
+         self.local_adapter_path = None
+
+         # Set minimal defaults; target_modules empty disables LoRA on all projections
+         self.r = 1 if r is None else r
+         self.lora_alpha = 1.0 if lora_alpha is None else lora_alpha
+         self.target_modules = []
+         self.bias = "none"
+         self.use_rslora = False
+         self.scaling_factor = 1.0
+
+         # Validate minimal settings
+         if not isinstance(self.r, int) or self.r <= 0:
+             raise ValueError(f"r must be a positive integer, got {self.r}")
+         if self.lora_alpha <= 0:
+             raise ValueError(f"lora_alpha must be positive, got {self.lora_alpha}")
+
+
+ class RBLNLoRAConfig(RBLNSerializableConfigProtocol):
+     """
+     Configuration class for multi-LoRA support in RBLN decoder-only models.
+
+     This class manages all LoRA adapters that will be compiled into the RBLN model.
+     Since RBLN NPU requires all adapters to be determined at compile time, this
+     configuration must specify all adapters upfront with their weights.
+
+     Key constraints for RBLN multi-LoRA:
+     1. All LoRA adapters must be specified at compile time
+     2. Adapter weights must be available during compilation
+     3. The number of adapters is fixed after compilation
+     4. Runtime can only switch between pre-compiled adapters
+     """
+
+     def __init__(
+         self, adapters: List[Union[Dict[str, Any], RBLNLoRAAdapterConfig]], max_lora_rank: Optional[int] = None
+     ):
+         """
+         Args:
+             adapters (List[Union[Dict[str, Any], RBLNLoRAAdapterConfig]]): List of LoRA adapters
+                 to be compiled into the model. Each adapter must be fully specified with weights
+                 accessible at compile time.
+             max_lora_rank (Optional[int]): Maximum rank across all adapters. If None, automatically
+                 determined from the provided adapters. Used for memory allocation optimization.
+
+         Raises:
+             ValueError: If adapters list is empty.
+             ValueError: If adapter IDs are not unique.
+             ValueError: If any adapter path doesn't exist.
+         """
+         if not adapters:
+             raise ValueError("adapters list cannot be empty")
+
+         # Convert dict adapters to RBLNLoRAAdapterConfig objects
+         self.adapters: List[RBLNLoRAAdapterConfig] = []
+         for adapter in adapters:
+             if isinstance(adapter, dict):
+                 self.adapters.append(RBLNLoRAAdapterConfig(**adapter))
+             elif isinstance(adapter, RBLNLoRAAdapterConfig):
+                 self.adapters.append(adapter)
+             else:
+                 raise ValueError(f"Invalid adapter type: {type(adapter)}")
+
+         # Disallow user-provided adapter with id 0: it's reserved for base model
+         if any(ad.lora_int_id == 0 for ad in self.adapters):
+             raise ValueError(
+                 "lora_int_id=0 is reserved for base model and cannot be provided. Please renumber your adapters to start from 1."
+             )
+
+         # Inject a reserved zero-weight adapter for base model at id=0
+         base_adapter = RBLNLoRABaseAdapterConfig()
+         self.adapters.insert(0, base_adapter)
+
+         # Sort adapters by ID to make IDs align with indices
+         self.adapters.sort(key=lambda a: a.lora_int_id)
+
+         # Validate unique and contiguous adapter IDs starting from 0
+         adapter_ids = [adapter.lora_int_id for adapter in self.adapters]
+         if len(adapter_ids) != len(set(adapter_ids)):
+             raise ValueError("All adapter IDs must be unique")
+         expected_ids = list(range(len(self.adapters)))
+         if adapter_ids != expected_ids:
+             raise ValueError(
+                 f"Adapter IDs must be contiguous and start from 0. Found {adapter_ids}, expected {expected_ids}."
+             )
+
+         # Calculate max_lora_rank if not provided
+         if max_lora_rank is None:
+             self.max_lora_rank = max(adapter.r for adapter in self.adapters)
+         else:
+             self.max_lora_rank = max_lora_rank
+             # Validate that max_lora_rank is sufficient
+             actual_max_rank = max(adapter.r for adapter in self.adapters)
+             if self.max_lora_rank < actual_max_rank:
+                 raise ValueError(
+                     f"max_lora_rank ({self.max_lora_rank}) must be >= actual max rank ({actual_max_rank})"
+                 )
+
+     @property
+     def num_adapters(self) -> int:
+         return len(self.adapters)
+
+     @property
+     def adapter_ids(self) -> List[int]:
+         return [adapter.lora_int_id for adapter in self.adapters]
+
+     @property
+     def adapter_names(self) -> List[str]:
+         return [adapter.lora_name for adapter in self.adapters]
+
+     def get_adapter_by_id(self, lora_int_id: int) -> Optional[RBLNLoRAAdapterConfig]:
+         for adapter in self.adapters:
+             if adapter.lora_int_id == lora_int_id:
+                 return adapter
+         return None
+
+     def get_adapter_by_name(self, lora_name: str) -> Optional[RBLNLoRAAdapterConfig]:
+         for adapter in self.adapters:
+             if adapter.lora_name == lora_name:
+                 return adapter
+         return None
+
+     def validate_adapter_weights(self) -> Dict[int, bool]:
+         validation_results = {}
+         for adapter in self.adapters:
+             try:
+                 # The reserved base adapter (id=0) always validates to True
+                 if adapter.lora_int_id == 0:
+                     validation_results[adapter.lora_int_id] = True
+                     continue
+                 # Check if adapter path exists and contains expected files
+                 adapter_path = adapter.local_adapter_path
+                 if adapter_path is not None and adapter_path.is_file():
+                     # Single file adapter (e.g., safetensors)
+                     validation_results[adapter.lora_int_id] = adapter_path.exists()
+                 else:
+                     # Directory adapter - check for common LoRA files
+                     expected_files = ["adapter_model.safetensors", "adapter_config.json"]
+                     alternative_files = ["pytorch_model.bin", "adapter_model.bin"]
+
+                     has_weights = adapter_path is not None and any(
+                         (adapter_path / f).exists() for f in expected_files + alternative_files
+                     )
+                     has_config = adapter_path is not None and (adapter_path / "adapter_config.json").exists()
+
+                     validation_results[adapter.lora_int_id] = has_weights and has_config
+             except Exception as e:
+                 logger.warning(f"Failed to validate adapter {adapter.lora_int_id}: {e}")
+                 validation_results[adapter.lora_int_id] = False
+
+         return validation_results
+
+     def _prepare_for_serialization(self) -> Dict[str, Any]:
+         # Do not serialize the reserved base adapter (id=0)
+         serializable_adapters = [adapter for adapter in self.adapters if adapter.lora_int_id != 0]
+         serializable_map = {
+             "adapters": [adapter._prepare_for_serialization() for adapter in serializable_adapters],
+             "max_lora_rank": self.max_lora_rank,
+         }
+         return serializable_map
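
For reference, a short sketch of the helper surface `RBLNLoRAConfig` exposes, reusing the two Hub adapters from the docstring example; constructing the config downloads the adapter files, and the reserved zero-effect base adapter is injected automatically at id 0:

```python
from optimum.rbln import RBLNLoRAAdapterConfig, RBLNLoRAConfig

lora_config = RBLNLoRAConfig(
    adapters=[
        RBLNLoRAAdapterConfig(1, "nemoguard", "nvidia/llama-3.1-nemoguard-8b-topic-control"),
        RBLNLoRAAdapterConfig(2, "abliterated", "reissbaker/llama-3.1-8b-abliterated-lora"),
    ],
)

print(lora_config.num_adapters)   # 3 (base adapter at id 0 plus the two user adapters)
print(lora_config.adapter_ids)    # [0, 1, 2]
print(lora_config.adapter_names)  # ['base', 'nemoguard', 'abliterated']

# Rank comes from the explicit `r` argument or each adapter's adapter_config.json
print(lora_config.get_adapter_by_name("nemoguard").r)

# Per-adapter weight/config presence check, keyed by lora_int_id; id 0 is always True
print(lora_config.validate_adapter_weights())
```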