optimum-rbln 0.9.2a3__py3-none-any.whl → 0.9.2a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic; see the registry's advisory page for more details.

Files changed (34)
  1. optimum/rbln/__init__.py +4 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +3 -0
  4. optimum/rbln/modeling.py +71 -1
  5. optimum/rbln/transformers/__init__.py +4 -0
  6. optimum/rbln/transformers/modeling_generic.py +23 -1
  7. optimum/rbln/transformers/models/__init__.py +4 -0
  8. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +65 -1
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  11. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  13. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +33 -0
  14. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  15. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +79 -4
  16. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  17. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  18. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +9 -1
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  20. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -4
  21. optimum/rbln/transformers/models/llava/modeling_llava.py +2 -1
  22. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -1
  23. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  24. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  25. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +0 -9
  26. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -0
  27. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +2 -0
  28. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  29. optimum/rbln/transformers/models/whisper/generation_whisper.py +15 -5
  30. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  31. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/METADATA +5 -5
  32. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/RECORD +34 -32
  33. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/WHEEL +0 -0
  34. {optimum_rbln-0.9.2a3.dist-info → optimum_rbln-0.9.2a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -118,6 +118,8 @@ _import_structure = {
118
118
  "RBLNLlavaForConditionalGenerationConfig",
119
119
  "RBLNLlavaNextForConditionalGeneration",
120
120
  "RBLNLlavaNextForConditionalGenerationConfig",
121
+ "RBLNLoRAAdapterConfig",
122
+ "RBLNLoRAConfig",
121
123
  "RBLNMidmLMHeadModel",
122
124
  "RBLNMidmLMHeadModelConfig",
123
125
  "RBLNMistralModel",
@@ -406,6 +408,8 @@ if TYPE_CHECKING:
406
408
  RBLNLlavaForConditionalGenerationConfig,
407
409
  RBLNLlavaNextForConditionalGeneration,
408
410
  RBLNLlavaNextForConditionalGenerationConfig,
411
+ RBLNLoRAAdapterConfig,
412
+ RBLNLoRAConfig,
409
413
  RBLNMidmLMHeadModel,
410
414
  RBLNMidmLMHeadModelConfig,
411
415
  RBLNMistralForCausalLM,
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.9.2a3'
32
- __version_tuple__ = version_tuple = (0, 9, 2, 'a3')
31
+ __version__ = version = '0.9.2a5'
32
+ __version_tuple__ = version_tuple = (0, 9, 2, 'a5')
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -41,6 +41,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
41
41
  class RBLNSerializableConfigProtocol(Protocol):
42
42
  def _prepare_for_serialization(self) -> Dict[str, Any]: ...
43
43
 
44
+ def __repr__(self) -> str:
45
+ return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
46
+
44
47
 
45
48
  @dataclass
46
49
  class RBLNCompileConfig:
optimum/rbln/modeling.py CHANGED
@@ -34,6 +34,49 @@ if TYPE_CHECKING:
34
34
  logger = get_logger(__name__)
35
35
 
36
36
 
37
+ def _get_dtype(
38
+ cls,
39
+ dtype: Optional[Union[str, torch.dtype, dict]],
40
+ config: PretrainedConfig,
41
+ ) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
42
+ dtype_orig = None
43
+
44
+ if dtype is not None:
45
+ if isinstance(dtype, str):
46
+ if dtype == "auto":
47
+ if hasattr(config, "dtype") and config.dtype is not None:
48
+ dtype = config.dtype
49
+ else:
50
+ dtype = torch.get_default_dtype()
51
+ elif hasattr(torch, dtype):
52
+ dtype = getattr(torch, dtype)
53
+ config.dtype = dtype
54
+ elif isinstance(dtype, torch.dtype):
55
+ config.dtype = dtype
56
+ elif isinstance(dtype, dict):
57
+ for key, curr_dtype in dtype.items():
58
+ if hasattr(config, key):
59
+ value = getattr(config, key)
60
+ curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
61
+ value.dtype = curr_dtype
62
+ # main torch dtype for modules that aren't part of any sub-config
63
+ dtype = dtype.get("")
64
+ dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
65
+ config.dtype = dtype
66
+ if dtype is None:
67
+ dtype = torch.float32
68
+ else:
69
+ raise ValueError(f"Invalid dtype: {dtype}")
70
+
71
+ dtype_orig = cls._set_default_dtype(dtype)
72
+ else:
73
+ # Use default dtype
74
+ default_dtype = torch.get_default_dtype()
75
+ config.dtype = default_dtype
76
+
77
+ return config, dtype, dtype_orig
78
+
79
+
37
80
  class RBLNModel(RBLNBaseModel):
38
81
  @classmethod
39
82
  def update_kwargs(cls, kwargs):
@@ -206,10 +249,37 @@ class RBLNModel(RBLNBaseModel):
206
249
  trust_remote_code: bool = False,
207
250
  # Some rbln-config should be applied before loading torch module (i.e. quantized llm)
208
251
  rbln_config: Optional[RBLNModelConfig] = None,
252
+ dtype: Optional[Union[str, torch.dtype, dict]] = None,
209
253
  **kwargs,
210
254
  ) -> "PreTrainedModel":
211
255
  kwargs = cls.update_kwargs(kwargs)
212
- return cls.get_hf_class().from_pretrained(
256
+
257
+ hf_class = cls.get_hf_class()
258
+
259
+ if dtype is not None:
260
+ config = hf_class.config_class.from_pretrained(
261
+ model_id,
262
+ subfolder=subfolder,
263
+ revision=revision,
264
+ cache_dir=cache_dir,
265
+ use_auth_token=use_auth_token,
266
+ local_files_only=local_files_only,
267
+ force_download=force_download,
268
+ trust_remote_code=trust_remote_code,
269
+ )
270
+
271
+ config, processed_dtype, dtype_orig = _get_dtype(
272
+ cls=hf_class,
273
+ dtype=dtype,
274
+ config=config,
275
+ )
276
+
277
+ kwargs["torch_dtype"] = processed_dtype
278
+
279
+ if dtype_orig is not None:
280
+ hf_class._set_default_dtype(dtype_orig)
281
+
282
+ return hf_class.from_pretrained(
213
283
  model_id,
214
284
  subfolder=subfolder,
215
285
  revision=revision,
@@ -110,6 +110,8 @@ _import_structure = {
110
110
  "RBLNPegasusModelConfig",
111
111
  "RBLNLlavaNextForConditionalGeneration",
112
112
  "RBLNLlavaNextForConditionalGenerationConfig",
113
+ "RBLNLoRAAdapterConfig",
114
+ "RBLNLoRAConfig",
113
115
  "RBLNMidmLMHeadModel",
114
116
  "RBLNMidmLMHeadModelConfig",
115
117
  "RBLNMistralForCausalLM",
@@ -258,6 +260,8 @@ if TYPE_CHECKING:
258
260
  RBLNLlavaForConditionalGenerationConfig,
259
261
  RBLNLlavaNextForConditionalGeneration,
260
262
  RBLNLlavaNextForConditionalGenerationConfig,
263
+ RBLNLoRAAdapterConfig,
264
+ RBLNLoRAConfig,
261
265
  RBLNMidmLMHeadModel,
262
266
  RBLNMidmLMHeadModelConfig,
263
267
  RBLNMistralForCausalLM,
@@ -23,6 +23,7 @@ different model architectures.
23
23
  import inspect
24
24
  from typing import TYPE_CHECKING, Optional, Union
25
25
 
26
+ from torch import nn
26
27
  from transformers import (
27
28
  AutoModel,
28
29
  AutoModelForAudioClassification,
@@ -57,6 +58,28 @@ class RBLNTransformerEncoder(RBLNModel):
57
58
  rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
58
59
  rbln_dtype = "int64"
59
60
 
61
+ @classmethod
62
+ def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
63
+ class TransformerEncoderWrapper(nn.Module):
64
+ # Parameters to disable for RBLN compilation
65
+ DISABLED_PARAMS = {"return_dict", "use_cache"}
66
+
67
+ def __init__(self, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig):
68
+ super().__init__()
69
+ self.model = model
70
+ self.rbln_config = rbln_config
71
+ self._forward_signature = inspect.signature(model.forward)
72
+
73
+ def forward(self, *args, **kwargs):
74
+ # Disable parameters that are not compatible with RBLN compilation
75
+ for param_name in self.DISABLED_PARAMS:
76
+ if param_name in self._forward_signature.parameters:
77
+ kwargs[param_name] = False
78
+
79
+ return self.model(*args, **kwargs)
80
+
81
+ return TransformerEncoderWrapper(model, rbln_config).eval()
82
+
60
83
  @classmethod
61
84
  def _update_rbln_config(
62
85
  cls,
@@ -208,7 +231,6 @@ class RBLNModelForQuestionAnswering(RBLNTransformerEncoder):
208
231
 
209
232
  def _prepare_output(self, output, return_dict):
210
233
  # Prepare QuestionAnswering specific output format.
211
-
212
234
  start_logits, end_logits = output
213
235
 
214
236
  if not return_dict:
@@ -96,6 +96,8 @@ _import_structure = {
96
96
  "RBLNDecoderOnlyModel",
97
97
  "RBLNDecoderOnlyModelForCausalLM",
98
98
  "RBLNDecoderOnlyModelForCausalLMConfig",
99
+ "RBLNLoRAAdapterConfig",
100
+ "RBLNLoRAConfig",
99
101
  ],
100
102
  "depth_anything": ["RBLNDepthAnythingForDepthEstimationConfig", "RBLNDepthAnythingForDepthEstimation"],
101
103
  "dpt": [
@@ -239,6 +241,8 @@ if TYPE_CHECKING:
239
241
  RBLNDecoderOnlyModelConfig,
240
242
  RBLNDecoderOnlyModelForCausalLM,
241
243
  RBLNDecoderOnlyModelForCausalLMConfig,
244
+ RBLNLoRAAdapterConfig,
245
+ RBLNLoRAConfig,
242
246
  )
243
247
  from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
244
248
  from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
@@ -31,6 +31,7 @@ from transformers.utils import logging
31
31
  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
32
32
  from ....modeling import RBLNModel
33
33
  from ...utils.rbln_runtime_wrapper import LoopProcessor
34
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
34
35
 
35
36
 
36
37
  logger = logging.get_logger(__name__)
@@ -265,7 +266,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
265
266
  )
266
267
 
267
268
 
268
- class RBLNBlip2ForConditionalGeneration(RBLNModel):
269
+ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
269
270
  """
270
271
  RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
271
272
  optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -433,3 +434,66 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
433
434
  )
434
435
 
435
436
  return inputs_embeds
437
+
438
+ @torch.no_grad()
439
+ def generate(
440
+ self,
441
+ pixel_values: torch.FloatTensor,
442
+ input_ids: Optional[torch.LongTensor] = None,
443
+ attention_mask: Optional[torch.LongTensor] = None,
444
+ inputs_embeds: Optional[torch.FloatTensor] = None,
445
+ interpolate_pos_encoding: bool = False,
446
+ **generate_kwargs,
447
+ ) -> torch.LongTensor:
448
+ batch_size = pixel_values.shape[0]
449
+ image_embeds = self.vision_model(
450
+ pixel_values,
451
+ return_dict=True,
452
+ interpolate_pos_encoding=interpolate_pos_encoding,
453
+ ).last_hidden_state
454
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
455
+
456
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
457
+ query_outputs = self.qformer(
458
+ query_embeds=query_tokens,
459
+ encoder_hidden_states=image_embeds,
460
+ encoder_attention_mask=image_attention_mask,
461
+ return_dict=True,
462
+ )
463
+ query_output = query_outputs.last_hidden_state
464
+
465
+ if query_output.dtype != image_embeds.dtype:
466
+ query_output = query_output.to(image_embeds.dtype)
467
+
468
+ language_model_inputs = self.language_projection(query_output)
469
+
470
+ if inputs_embeds is None:
471
+ if input_ids is None:
472
+ image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
473
+ start_tokens = image_tokens + [self.config.text_config.bos_token_id]
474
+ input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
475
+ input_ids = input_ids.repeat(batch_size, 1)
476
+ inputs_embeds = self.get_input_embeddings()(input_ids)
477
+
478
+ if attention_mask is None:
479
+ attention_mask = torch.ones_like(input_ids)
480
+
481
+ if input_ids is None:
482
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
483
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
484
+ )
485
+ special_image_mask = special_image_mask.all(-1)
486
+ else:
487
+ special_image_mask = input_ids == self.config.image_token_id
488
+
489
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
490
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
491
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
492
+
493
+ inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
494
+ if not self.language_model.config.is_encoder_decoder:
495
+ inputs["input_ids"] = input_ids
496
+
497
+ outputs = self.language_model.generate(**inputs, **generate_kwargs)
498
+
499
+ return outputs
@@ -23,4 +23,5 @@ from ....ops import (
23
23
  paged_flash_causal_attn_prefill,
24
24
  )
25
25
  from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
26
+ from .configuration_lora import RBLNLoRAAdapterConfig, RBLNLoRAConfig
26
27
  from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
@@ -17,6 +17,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, get_args
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
  from ....utils.logging import get_logger
19
19
  from ...utils.rbln_quantization import RBLNQuantizationConfig
20
+ from .configuration_lora import RBLNLoRAConfig
20
21
 
21
22
 
22
23
  logger = get_logger()
@@ -48,6 +49,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
48
49
  kvcache_partition_len: Optional[int] = None,
49
50
  kvcache_block_size: Optional[int] = None,
50
51
  quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
52
+ lora_config: Optional[Union[Dict[str, Any], RBLNLoRAConfig]] = None,
51
53
  prefill_chunk_size: Optional[int] = None,
52
54
  kvcache_num_blocks: Optional[int] = None,
53
55
  decoder_batch_sizes: Optional[List[int]] = None,
@@ -80,6 +82,12 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
80
82
  kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
81
83
  in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
82
84
  section below for details.
85
+ quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
86
+ quantization. Specifies format, etc.
87
+ lora_config (Optional[Union[Dict[str, Any], RBLNLoRAConfig]]): Configuration for LoRA
88
+ (Low-Rank Adaptation) settings when using (multi-)LoRA support. Can be provided as
89
+ a dictionary or an RBLNLoRAConfig instance. When provided, enables LoRA functionality
90
+ for the model compilation. Defaults to None (no LoRA).
83
91
  prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
84
92
  processing input sequences. Defaults to 128. Must be a positive integer
85
93
  divisible by 64. Affects prefill performance and memory usage.
@@ -185,6 +193,26 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
185
193
  if self.quantization and isinstance(self.quantization, dict):
186
194
  self.quantization = RBLNQuantizationConfig(**self.quantization)
187
195
 
196
+ self.lora_config = lora_config
197
+ if self.lora_config and isinstance(self.lora_config, dict):
198
+ self.lora_config = RBLNLoRAConfig(**self.lora_config)
199
+
200
+ # Validate LoRA adapters if LoRA is enabled
201
+ if self.lora_config is not None:
202
+ validation_results = self.lora_config.validate_adapter_weights()
203
+ failed_adapters = [adapter_id for adapter_id, is_valid in validation_results.items() if not is_valid]
204
+
205
+ if failed_adapters:
206
+ raise ValueError(
207
+ f"Some LoRA adapters failed validation and may not be accessible at compile time: {failed_adapters}. "
208
+ "Please ensure all adapter weights are available and properly formatted."
209
+ )
210
+
211
+ logger.info(
212
+ f"LoRA configuration initialized with {self.lora_config.num_adapters} adapters: "
213
+ f"{self.lora_config.adapter_ids}. Max rank: {self.lora_config.max_lora_rank}"
214
+ )
215
+
188
216
  self.attn_impl = attn_impl
189
217
  self.kvcache_partition_len = kvcache_partition_len
190
218
  self.kvcache_block_size = kvcache_block_size
@@ -204,6 +232,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
204
232
  if self.logits_to_keep is not None and self.logits_to_keep > 1:
205
233
  raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")
206
234
 
235
+ self.decoder_batch_sizes = None
207
236
  if "decode" in self.phases:
208
237
  self.decoder_batch_sizes = decoder_batch_sizes
209
238
  if self.decoder_batch_sizes is None:
@@ -243,6 +272,11 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
243
272
  def use_multiple_decoder(self) -> bool:
244
273
  return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
245
274
 
275
+ @property
276
+ def use_lora(self):
277
+ """Check if LoRA is enabled for this configuration."""
278
+ return self.lora_config is not None
279
+
246
280
  @property
247
281
  def can_generate(self) -> bool:
248
282
  return "decode" in self.phases