optimum-rbln 0.8.4a8__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +63 -32
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  11. optimum/rbln/diffusers/modeling_diffusers.py +5 -7
  12. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  13. optimum/rbln/modeling.py +50 -0
  14. optimum/rbln/modeling_base.py +1 -2
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_generic.py +37 -1
  17. optimum/rbln/transformers/models/__init__.py +9 -0
  18. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  19. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +86 -23
  20. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  21. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  22. optimum/rbln/transformers/models/colpali/configuration_colpali.py +34 -18
  23. optimum/rbln/transformers/models/colpali/modeling_colpali.py +73 -80
  24. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  25. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  26. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  27. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  28. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  29. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  30. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  32. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +50 -2
  33. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  34. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +65 -3
  35. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  36. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  37. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  38. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +67 -44
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +24 -19
  42. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  43. optimum/rbln/transformers/models/llava/modeling_llava.py +108 -50
  44. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  45. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -343
  46. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  47. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +9 -8
  50. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +24 -0
  51. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +24 -0
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  55. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  57. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  58. optimum/rbln/utils/runtime_utils.py +25 -15
  59. optimum/rbln/utils/submodule.py +21 -5
  60. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA +7 -6
  61. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/RECORD +64 -55
  62. optimum_rbln-0.9.2.dist-info/entry_points.txt +2 -0
  63. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -123,7 +123,10 @@ class MidmAttention(DecoderOnlyAttention):
         self.split_size = self._original_mod.split_size
         self.num_key_value_heads = self._original_mod.num_heads
 
-    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if lora_int_id is not None:
+            raise NotImplementedError("LoRA is not supported for MidmAttention")
+
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states
 

optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -56,7 +56,10 @@ class PhiAttention(DecoderOnlyAttention):
         self.qk_layernorm = self._original_mod.qk_layernorm
         self.rotary_ndims = self._original_mod.rotary_ndims
 
-    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if lora_int_id is not None:
+            raise NotImplementedError("LoRA is not supported for PhiAttention")
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -84,6 +87,7 @@ class PhiLayer(DecoderOnlyLayer):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
 
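
Both hunks change the `projection` hook on `DecoderOnlyAttention` subclasses to accept a `lora_int_id` argument, with backends that lack LoRA support rejecting a non-None value up front. Below is a minimal sketch of a custom attention override following the same updated signature; the class name is hypothetical, the plain q/k/v projection layers and the import path are assumptions, and only the signature and the guard come from the hunks above.

from typing import Tuple

import torch

# Assumed import path for DecoderOnlyAttention; adjust to the installed package layout.
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyAttention


class MyModelAttention(DecoderOnlyAttention):
    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Follow the Midm/Phi pattern: refuse LoRA requests rather than silently ignoring them.
        if lora_int_id is not None:
            raise NotImplementedError("LoRA is not supported for MyModelAttention")
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        return query_states, key_states, value_states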

optimum/rbln/transformers/models/pixtral/modeling_pixtral.py
@@ -36,8 +36,6 @@ logger = get_logger(__name__)
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
-    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
-
 
 class RBLNRuntimePixtralVisionModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
@@ -128,8 +126,11 @@ class RBLNRuntimePixtralVisionModel(RBLNPytorchRuntime):
                 (1, patch_embed_seq.shape[-2]), fill_value=torch.finfo(patch_embed_seq.dtype).min
             )
             attention_mask[:, : h_patched_original * w_patched_original] = 0
-
-            transformer_output = super().forward(patch_embed_seq, attention_mask, cos, sin)
+            if "out" in kwargs:
+                super().forward(patch_embed_seq, attention_mask, cos, sin, **kwargs)
+                transformer_output = kwargs["out"]
+            else:
+                transformer_output = super().forward(patch_embed_seq, attention_mask, cos, sin, **kwargs)
 
             last_hidden_state_list.append(transformer_output[0][:, : h_patched_original * w_patched_original, :])
             hidden_states = transformer_output[1:]
@@ -236,12 +237,6 @@ class RBLNPixtralVisionModel(RBLNModel):
         }
         return _PixtralVisionModel(model, **wrapper_cfg).eval()
 
-    @classmethod
-    def update_rbln_config_using_pipe(
-        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
-    ) -> "RBLNDiffusionMixinConfig":
-        return rbln_config
-
     @classmethod
     def _update_rbln_config(
         cls,
@@ -309,7 +304,7 @@ class RBLNPixtralVisionModel(RBLNModel):
         )
 
         output = self.model(
-            pixel_values, image_sizes, output_hidden_states=output_hidden_states, return_dict=return_dict
+            pixel_values, image_sizes, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs
        )
 
         return output

optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -23,6 +23,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionPatchEmbed,
     Qwen2_5_VisionRotaryEmbedding,
     Qwen2_5_VisionTransformerPretrainedModel,
+    Qwen2_5_VLModel,
     Qwen2_5_VLRotaryEmbedding,
 )
 
@@ -392,13 +393,12 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         return True
 
     @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "_attn_implementation": "eager",
-            }
-        )
-        return super().update_kwargs(kwargs)
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+        model.model.lm_head = model.lm_head
+        model.lm_head = None
+        del model.lm_head
+        return model
 
     @classmethod
     def get_input_info(
@@ -532,7 +532,8 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             vision_tokens = input_id[0][vision_start_indices + 1]
             image_nums = (vision_tokens == image_token_id).sum()
             video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = self.get_rope_index(
+            position_ids, rope_deltas = Qwen2_5_VLModel.get_rope_index(
+                self,
                 input_id,
                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,

optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py
@@ -3,6 +3,7 @@ from typing import Tuple
 
 import torch
 import torch.nn as nn
+from transformers import PreTrainedModel
 
 from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper, apply_rotary_pos_emb
 
@@ -164,6 +165,7 @@ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
         position_embeds = args.pop(0)
         query_position = args.pop(0) if self.phase == "prefill" else None
         position_ids = None
+        lora_int_id = None
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         past_key_values = args
 
@@ -191,6 +193,28 @@ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             position_embeds,
         )
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+        new_layers = []
+
+        for layer_idx, layer in enumerate(model.model.language_model.layers):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layers.append(new_layer)
+
+        new_model = self.get_rbln_model_class()(
+            model.model.language_model,
+            new_layers,
+            self.rbln_config,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+        )
+
+        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
+        return new_model

optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -27,6 +27,7 @@ from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     Qwen2VisionTransformerPretrainedModel,
+    Qwen2VLModel,
     Qwen2VLRotaryEmbedding,
     VisionRotaryEmbedding,
 )
@@ -280,6 +281,14 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
     def can_generate(self):
         return True
 
+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+        model.model.lm_head = model.lm_head
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     @classmethod
     def get_input_info(
         cls,
@@ -402,7 +411,8 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             vision_tokens = input_id[0][vision_start_indices + 1]
             image_nums = (vision_tokens == image_token_id).sum()
             video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = self.get_rope_index(
+            position_ids, rope_deltas = Qwen2VLModel.get_rope_index(
+                self,
                 input_id,
                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,

optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py
@@ -3,6 +3,7 @@ from typing import Tuple
 
 import torch
 import torch.nn as nn
+from transformers import PreTrainedModel
 
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyWrapper,
@@ -110,6 +111,7 @@ class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
         query_position = args.pop(0) if self.phase == "prefill" else None
         position_ids = None
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
         past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -136,6 +138,28 @@ class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             position_embeds,
         )
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+        new_layers = []
+
+        for layer_idx, layer in enumerate(model.model.language_model.layers):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layers.append(new_layer)
+
+        new_model = self.get_rbln_model_class()(
+            model.model.language_model,
+            new_layers,
+            self.rbln_config,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+        )
+
+        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
+        return new_model

optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -20,6 +20,7 @@ import rebel
 import torch
 from rebel.compile_context import CompileContext
 from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
+from transformers.generation.utils import GenerationMixin
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
 from ....configuration_utils import RBLNCompileConfig
@@ -101,7 +102,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=lm_logits)
 
 
-class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
+class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -117,6 +118,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
     support_causal_attn = None
+    _is_stateful = False
 
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.batch_size

optimum/rbln/transformers/models/siglip/modeling_siglip.py
@@ -29,8 +29,6 @@ logger = get_logger(__name__)
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
-    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
-
 
 class _SiglipVisionModel(torch.nn.Module):
     def __init__(
@@ -65,6 +63,8 @@ class RBLNSiglipVisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
     """
 
+    _tp_support = False
+
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
         wrapper_cfg = {
@@ -74,12 +74,6 @@ class RBLNSiglipVisionModel(RBLNModel):
         }
         return _SiglipVisionModel(model, **wrapper_cfg).eval()
 
-    @classmethod
-    def update_rbln_config_using_pipe(
-        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
-    ) -> "RBLNDiffusionMixinConfig":
-        return rbln_config
-
     @classmethod
     def _update_rbln_config(
         cls,
@@ -128,11 +122,6 @@ class RBLNSiglipVisionModel(RBLNModel):
         interpolate_pos_encoding: bool = False,
         **kwargs: Any,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
-        if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
-            logger.warning(
-                f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
-            )
-
         output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
@@ -156,7 +145,7 @@ class RBLNSiglipVisionModel(RBLNModel):
                 f"Please compile again with the correct argument."
             )
 
-        output = super().forward(pixel_values, return_dict=return_dict)
+        output = super().forward(pixel_values, return_dict=return_dict, **kwargs)
         return output
 
     def _prepare_output(self, output, return_dict):

optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -39,14 +39,31 @@ from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
 
 
 class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
-    """
-    This class is based on transformers version 4.44.2.
-    It uses the same generate() method, so it's crucial to maintain the inheritance order.
-    Ensure WhisperGenerationMixin is listed before GenerationMixin.
-    """
+    def generate(self, *args, generation_config=None, **kwargs):
+        num_beams = kwargs.get(
+            "num_beams",
+            generation_config.num_beams
+            if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
+            else 1,
+        )
+        if num_beams > 1:
+            raise ValueError(
+                f"Beam search is not supported in RBLNWhisperGenerationMixin. "
+                f"Received num_beams={num_beams}, but only num_beams=1 is allowed. "
+                f"Please set num_beams=1 for greedy search or adjust your configuration."
+            )
+
+        return super().generate(*args, **kwargs)
 
     def _postprocess_outputs(
-        self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, *args, **kwargs
+        self,
+        seek_outputs,
+        decoder_input_ids,
+        return_token_timestamps,
+        generation_config,
+        is_shortform,
+        seek,
+        batch_idx_map,
     ):
         # remove all previously passed decoder input ids
         # should happen only if it is the first generated segment
@@ -64,6 +81,11 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
 
         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
             num_frames = getattr(generation_config, "num_frames", None)
+
+            if num_frames is not None:
+                num_frames = num_frames - seek
+                num_frames = num_frames[batch_idx_map]
+
             if version.parse(transformers.__version__) >= version.parse("4.46.0"):
                 seek_outputs["token_timestamps"] = self._extract_token_timestamps(
                     seek_outputs,

optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -150,7 +150,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
 
     auto_model_class = AutoModelForSpeechSeq2Seq
-    main_input_name = "input_ids"
+    main_input_name = "input_features"
+    _is_stateful = False
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
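
Taken together, the Whisper changes pin `main_input_name` to `input_features`, mark the model as stateless for generation, and reject beam search in `generate()`. A short usage sketch follows, assuming the model class is exported at the package top level and compiled via `from_pretrained(..., export=True)`; the checkpoint id and the dummy audio are placeholders.

import numpy as np
from transformers import AutoProcessor

from optimum.rbln import RBLNWhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = RBLNWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", export=True)

# One second of silence stands in for real audio here.
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

# num_beams must stay at 1; the new generate() override raises ValueError for anything else.
generated_ids = model.generate(inputs.input_features, num_beams=1)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))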

optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (new file)
@@ -0,0 +1,79 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
+
+from torch.nn import Module
+
+from ...modeling import RBLNModel
+
+
+if TYPE_CHECKING:
+    import rebel
+
+
+class LoopProcessor(Module, ABC):
+    def __init__(self, model: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__()
+        self.model = model
+
+    def __repr__(self) -> str:
+        return repr(self.model)
+
+    def _is_batch_implemented(self) -> bool:
+        return self._forward_batch.__func__ is not LoopProcessor._forward_batch
+
+    def forward(self, *args, force_loop: bool = False, **kwargs) -> Any:
+        if not force_loop and self._is_batch_implemented():
+            return self._forward_batch(*args, **kwargs)
+        else:
+            return self._forward_loop(*args, **kwargs)
+
+    def _forward_loop(self, *args, **kwargs) -> Any:
+        batch_size = self._get_batch_size(*args, **kwargs)
+
+        if not isinstance(batch_size, int) or batch_size == 0:
+            return self._process_outputs([])
+
+        common_inputs = self._prepare_inputs_before_loop(*args, **kwargs)
+
+        outputs = []
+        for i in range(batch_size):
+            item_args, item_kwargs = self._prepare_inputs_for_iteration(i, common_inputs, *args, **kwargs)
+            item_output = self.model(*item_args, **item_kwargs)
+            outputs.append(item_output)
+
+        return self._process_outputs(outputs, **kwargs)
+
+    def _forward_batch(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("The batch processing logic (_forward_batch) is not implemented in this class.")
+
+    @abstractmethod
+    def _get_batch_size(self, *args, **kwargs) -> int:
+        pass
+
+    @abstractmethod
+    def _prepare_inputs_for_iteration(
+        self, index: int, common_inputs: Dict[str, Any], *args, **kwargs
+    ) -> Tuple[List[Any], Dict[str, Any]]:
+        pass
+
+    def _prepare_inputs_before_loop(self, *args, **kwargs) -> Dict[str, Any]:
+        pass
+
+    @abstractmethod
+    def _process_outputs(self, outputs: List[Any], **kwargs) -> Any:
+        pass
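
`LoopProcessor` wraps a compiled runtime and either dispatches to a batched `_forward_batch` (when a subclass provides one) or falls back to looping over the batch one item at a time. A minimal sketch of a concrete subclass follows, assuming a runtime that accepts a single-item tensor per call; the subclass name, input name, and shapes are illustrative, and the import path is taken from the file list above.

import torch

# The wrapped runtime can be any RBLNModel or rebel.Runtime instance you already hold.
from optimum.rbln.transformers.utils.rbln_runtime_wrapper import LoopProcessor


class PerImageProcessor(LoopProcessor):
    def _get_batch_size(self, pixel_values, **kwargs) -> int:
        return pixel_values.shape[0]

    def _prepare_inputs_before_loop(self, pixel_values, **kwargs):
        return {}  # nothing is shared across iterations in this sketch

    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
        # Feed one batch element at a time to the wrapped runtime.
        return [pixel_values[index : index + 1]], {}

    def _process_outputs(self, outputs, **kwargs):
        # Stitch per-item results back into a single batch tensor.
        return torch.cat(outputs, dim=0) if outputs else outputs


# processor = PerImageProcessor(runtime)
# result = processor(pixel_values)  # loops, since _forward_batch is not overridden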

optimum/rbln/utils/runtime_utils.py
@@ -167,33 +167,44 @@ class ContextRblnConfig:
         device=None,
         device_map=None,
         create_runtimes=None,
-        optimize_host_mem=None,
         activate_profiler=None,
         timeout=None,
     ):
         self.device = device
         self.device_map = device_map
         self.create_runtimes = create_runtimes
-        self.optimize_host_mem = optimize_host_mem
         self.activate_profiler = activate_profiler
         self.timeout = timeout
+        self._previous_context = None
 
     def __enter__(self):
-        self._local.device = self.device
-        self._local.device_map = self.device_map
-        self._local.create_runtimes = self.create_runtimes
-        self._local.optimize_host_memory = self.optimize_host_mem
-        self._local.activate_profiler = self.activate_profiler
-        self._local.timeout = self.timeout
+        self._previous_context = {
+            "device": getattr(self._local, "device", None),
+            "device_map": getattr(self._local, "device_map", None),
+            "create_runtimes": getattr(self._local, "create_runtimes", None),
+            "activate_profiler": getattr(self._local, "activate_profiler", None),
+            "timeout": getattr(self._local, "timeout", None),
+        }
+
+        if self.device is not None:
+            self._local.device = self.device
+        if self.device_map is not None:
+            self._local.device_map = self.device_map
+        if self.create_runtimes is not None:
+            self._local.create_runtimes = self.create_runtimes
+        if self.activate_profiler is not None:
+            self._local.activate_profiler = self.activate_profiler
+        if self.timeout is not None:
+            self._local.timeout = self.timeout
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self._local.device = None
-        self._local.device_map = None
-        self._local.create_runtimes = None
-        self._local.optimize_host_memory = None
-        self._local.activate_profiler = None
-        self._local.timeout = None
+        if self._previous_context is not None:
+            self._local.device = self._previous_context["device"]
+            self._local.device_map = self._previous_context["device_map"]
+            self._local.create_runtimes = self._previous_context["create_runtimes"]
+            self._local.activate_profiler = self._previous_context["activate_profiler"]
+            self._local.timeout = self._previous_context["timeout"]
 
     @classmethod
     def get_current_context(cls):
@@ -201,7 +212,6 @@ class ContextRblnConfig:
             "device": getattr(cls._local, "device", None),
             "device_map": getattr(cls._local, "device_map", None),
             "create_runtimes": getattr(cls._local, "create_runtimes", None),
-            "optimize_host_memory": getattr(cls._local, "optimize_host_memory", None),
             "activate_profiler": getattr(cls._local, "activate_profiler", None),
             "timeout": getattr(cls._local, "timeout", None),
         }
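
`ContextRblnConfig.__enter__` now snapshots the current thread-local values and only overrides the fields that were explicitly passed, and `__exit__` restores the snapshot instead of blanking everything, so nested contexts compose. A small usage sketch, assuming the class is importable from its module path; the device id is a placeholder, and the parameters and `get_current_context()` come from the code above.

from optimum.rbln.utils.runtime_utils import ContextRblnConfig

with ContextRblnConfig(device=0, create_runtimes=True):
    # The inner context only touches activate_profiler; device and create_runtimes
    # set by the outer context remain in effect.
    with ContextRblnConfig(activate_profiler=True):
        ctx = ContextRblnConfig.get_current_context()
        # ctx == {"device": 0, "device_map": None, "create_runtimes": True,
        #         "activate_profiler": True, "timeout": None}
    # Leaving the inner context restores the snapshot taken on entry.
    assert ContextRblnConfig.get_current_context()["activate_profiler"] is None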

optimum/rbln/utils/submodule.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
 
 from transformers import PretrainedConfig
 
-from ..configuration_utils import RBLNModelConfig
+from ..configuration_utils import RBLNModelConfig, get_rbln_config_class
 from ..utils.model_utils import get_rbln_model_cls
 
 
@@ -41,6 +41,15 @@ class SubModulesMixin:
         for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
             setattr(self, submodule_meta["name"], submodule)
 
+    @classmethod
+    def _get_submodule_config_class(
+        cls, cls_name: str, submodule_rbln_config: Dict[str, Any]
+    ) -> Type[RBLNModelConfig]:
+        if isinstance(submodule_rbln_config, dict) and "cls_name" in submodule_rbln_config:
+            config_cls_name = submodule_rbln_config["cls_name"]
+            return get_rbln_config_class(config_cls_name)
+        return get_rbln_config_class(f"RBLN{cls_name}Config")
+
     @classmethod
     def _update_submodule_config(
         cls,
@@ -69,12 +78,19 @@ class SubModulesMixin:
             cls_name = torch_submodule.__class__.__name__
             submodule_cls: Type["RBLNModel"] = get_rbln_model_cls(f"RBLN{cls_name}")
             submodule_rbln_config = getattr(rbln_config, submodule_name) or {}
+            submodule_config_cls = cls._get_submodule_config_class(cls_name, submodule_rbln_config)
 
            if isinstance(submodule_rbln_config, dict):
-                submodule_rbln_config_class = submodule_cls.get_rbln_config_class()
-                submodule_rbln_config = submodule_rbln_config_class(**submodule_rbln_config)
-                setattr(rbln_config, submodule_name, submodule_rbln_config)
-
+                filtered_kwargs = rbln_config.filter_parameters(submodule_config_cls, submodule_rbln_config)
+                filtered_kwargs["cls_name"] = submodule_config_cls.__name__
+                submodule_rbln_config = submodule_config_cls(**filtered_kwargs)
+            elif not isinstance(submodule_rbln_config, submodule_config_cls):
+                config_dict = {k: v for k, v in submodule_rbln_config.__dict__.items() if not k.startswith("_")}
+                filtered_kwargs = rbln_config.filter_parameters(submodule_config_cls, config_dict)
+                filtered_kwargs["cls_name"] = submodule_config_cls.__name__
+                submodule_rbln_config = submodule_config_cls(**filtered_kwargs)
+
+            setattr(rbln_config, submodule_name, submodule_rbln_config)
            submodule_rbln_config = submodule_cls._update_submodule_config(model, submodule_rbln_config, preprocessors)
 
            rbln_submodule = submodule_cls.from_model(
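
`_get_submodule_config_class` resolves a submodule's config class from an explicit `cls_name` key in the user-supplied dict, falling back to the `RBLN{ClassName}Config` naming convention. In terms of the `rbln_config` dicts a caller passes, the two paths look roughly like the sketch below; the `vision_tower` submodule name is illustrative, and the option shown is one that appears elsewhere in this diff.

# Convention path: a SiglipVisionModel submodule resolves to RBLNSiglipVisionModelConfig.
rbln_config = {"vision_tower": {"output_hidden_states": True}}

# Explicit path: the dict names the config class to instantiate, and the remaining
# keys are filtered against that class's parameters before construction.
rbln_config = {
    "vision_tower": {
        "cls_name": "RBLNSiglipVisionModelConfig",
        "output_hidden_states": True,
    }
}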

{optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.8.4a8
+Version: 0.9.2
 Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
@@ -20,15 +20,16 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: <3.13,>=3.9
+Requires-Python: <3.14,>=3.9
 Requires-Dist: accelerate>=1.0.1
 Requires-Dist: diffusers==0.35.1
 Requires-Dist: packaging>=24.1
-Requires-Dist: torch==2.7.0
-Requires-Dist: torchaudio<=2.7.0
-Requires-Dist: torchvision<=0.22.0
-Requires-Dist: transformers==4.51.3
+Requires-Dist: torch==2.8.0
+Requires-Dist: torchaudio<=2.8.0
+Requires-Dist: torchvision<=0.23.0
+Requires-Dist: transformers==4.57.1
 Description-Content-Type: text/markdown
 
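
The dependency pins move to the torch 2.8.0 / transformers 4.57.1 line, and Python 3.13 is now accepted (Requires-Python raised to <3.14). A quick sanity check of an environment against the new pins might look like this; the version strings are taken from the METADATA hunk above.

import sys

import torch
import transformers

assert (3, 9) <= sys.version_info[:2] < (3, 14), sys.version
assert torch.__version__.startswith("2.8."), torch.__version__
assert transformers.__version__ == "4.57.1", transformers.__version__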