optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/siglip/modeling_siglip.py
@@ -21,6 +21,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
+ from ...modeling_outputs import _validate_output_attentions, _validate_output_hidden_states
  from .configuration_siglip import RBLNSiglipVisionModelConfig


@@ -52,7 +53,7 @@ class _SiglipVisionModel(torch.nn.Module):
  interpolate_pos_encoding=self.interpolate_pos_encoding,
  output_attentions=self.output_attentions,
  )
- return tuple(x for x in enc_out if x is not None)
+ return enc_out


  class RBLNSiglipVisionModel(RBLNModel):
@@ -66,7 +67,9 @@ class RBLNSiglipVisionModel(RBLNModel):
  _tp_support = False

  @classmethod
- def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
+ def _wrap_model_if_needed(
+ cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
+ ) -> torch.nn.Module:
  wrapper_cfg = {
  "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
  "output_hidden_states": rbln_config.output_hidden_states,
@@ -122,23 +125,22 @@ class RBLNSiglipVisionModel(RBLNModel):
  interpolate_pos_encoding: bool = False,
  **kwargs: Any,
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
- output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
- )
-
- if output_attentions != self.rbln_config.output_attentions:
- raise ValueError(
- f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
- f"Please compile again with the correct argument."
- )
-
- if output_hidden_states != self.rbln_config.output_hidden_states:
- raise ValueError(
- f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
- f"Please compile again with the correct argument."
- )
-
+ """
+ Forward pass for the RBLN-optimized SigLIP vision model.
+
+ Args:
+ pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.__call__() for details (processor_class uses ViTImageProcessor for processing images).
+ return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+ output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+ output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+ interpolate_pos_encoding (bool, defaults to False): Whether to interpolate the pre-trained position encodings.
+
+ Returns:
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+ """
+
+ output_attentions = _validate_output_attentions(output_attentions, self.rbln_config)
+ output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
  if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
  raise ValueError(
  f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
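
Note: the inline output-flag checks in forward are replaced by the shared _validate_output_attentions / _validate_output_hidden_states helpers from modeling_outputs. A minimal usage sketch of the resulting behavior; the checkpoint id and the rbln_* keyword arguments are illustrative assumptions about the usual optimum-rbln export flow, not taken from this diff:

    import torch
    from optimum.rbln import RBLNSiglipVisionModel

    # Output flags are fixed at compile time; call-time values that disagree raise a ValueError
    # instead of being silently overridden.
    model = RBLNSiglipVisionModel.from_pretrained(
        "google/siglip-base-patch16-224",   # illustrative checkpoint
        export=True,
        rbln_output_hidden_states=False,
        rbln_output_attentions=False,
    )
    pixel_values = torch.randn(1, 3, 224, 224)   # stand-in for processor output
    outputs = model(pixel_values)                # output_attentions=True here would raise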

optimum/rbln/transformers/models/swin/configuration_swin.py
@@ -32,11 +32,6 @@ class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
  Raises:
  ValueError: If batch_size is not a positive integer.
  """
- super().__init__(**kwargs)
- self.batch_size = batch_size or 1
- if not isinstance(self.batch_size, int) or self.batch_size < 0:
- raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
-
- self.image_size = image_size
+ super().__init__(batch_size=batch_size, image_size=image_size, **kwargs)
  self.output_hidden_states = output_hidden_states
  self.output_attentions = output_attentions

optimum/rbln/transformers/models/swin/modeling_swin.py
@@ -203,7 +203,7 @@ class _SwinBackbone(torch.nn.Module):

  class RBLNSwinBackbone(RBLNModel):
  @classmethod
- def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
+ def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
  for layer in model.encoder.layers:
  for block in layer.blocks:
  block.get_attn_mask = types.MethodType(get_attn_mask, block)
@@ -278,6 +278,19 @@ class RBLNSwinBackbone(RBLNModel):
  output_hidden_states: bool = None,
  **kwargs,
  ) -> Union[Tuple, BackboneOutput]:
+ """
+ Forward pass for the RBLN-optimized Swin backbone model.
+
+ Args:
+ pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.__call__() for details (processor_class uses ViTImageProcessor for processing images).
+ return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+ output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+ output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+
+ Returns:
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BackboneOutput object.
+ """
+
  if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
  logger.warning(
  f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
@@ -314,19 +327,19 @@ class RBLNSwinBackbone(RBLNModel):
  output = self.model[0](padded_pixel_values)

  feature_maps = ()
- for i in range(len(self.config.out_features)):
+ for _ in range(len(self.config.out_features)):
  feature_maps += (output.pop(0),)

  if self.rbln_config.output_hidden_states:
  hidden_states = ()
- for i in range(len(self.config.stage_names)):
+ for _ in range(len(self.config.stage_names)):
  hidden_states += (output.pop(0),)
  else:
  hidden_states = None

  if self.rbln_config.output_attentions:
  attentions = ()
- for i in range(len(self.config.depths)):
+ for _ in range(len(self.config.depths)):
  attentions += (output.pop(0),)
  else:
  attentions = None

optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -68,7 +68,7 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
  output_class = BaseModelOutputWithPastAndCrossAttentions

  @classmethod
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
+ def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
  return T5EncoderWrapper(model)

  @classmethod
@@ -113,7 +113,7 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
  support_causal_attn = False

  @classmethod
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
+ def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
  return T5Wrapper(
  model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
  )

optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -39,7 +39,7 @@ class T5Wrapper:

  class T5EncoderWrapper(Seq2SeqEncoderWrapper):
  def __post_init__(self, model: nn.Module):
- self.n_layer = getattr(self.config, "num_layers")
+ self.n_layer = self.config.num_layers
  self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().block)
  self.num_heads = self.config.num_heads
  self.d_kv = self.config.d_kv
@@ -111,9 +111,9 @@ class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
  class T5Decoder(Seq2SeqDecoder):
  has_pos_emb = False

- def __post_init__(self, dec_max_seq_len: int = None):
- self.invert_attention_mask = self._original_mod.invert_attention_mask
- self._dec_position_bias = self.precompute_dec_position_bias(self._original_mod, dec_max_seq_len)
+ def __post_init__(self, model: nn.Module, dec_max_seq_len: int = None):
+ self.invert_attention_mask = model.invert_attention_mask
+ self._dec_position_bias = self.precompute_dec_position_bias(model, dec_max_seq_len)

  def precompute_dec_position_bias(self, model, dec_max_length):
  attn_layer = model.block[0].layer[0].SelfAttention
@@ -145,13 +145,12 @@ class T5Decoder(Seq2SeqDecoder):
  class T5Block(Seq2SeqDecoderLayer):
  def __init__(self, decoder_layer, self_attn):
  super().__init__(decoder_layer, self_attn, cross_attn=None)
- self.__post_init__()

- def __post_init__(self):
- self.self_attn_layer_norm = self._original_mod.layer[0].layer_norm
- self.encoder_attn_layer_norm = self._original_mod.layer[1].layer_norm
- self.cross_attn = T5CrossAttention(self._original_mod.layer[1].EncDecAttention)
- self.ff_layer = self._original_mod.layer[2]
+ def __post_init__(self, decoder_layer: nn.Module):
+ self.self_attn_layer_norm = decoder_layer.layer[0].layer_norm
+ self.encoder_attn_layer_norm = decoder_layer.layer[1].layer_norm
+ self.cross_attn = T5CrossAttention(decoder_layer.layer[1].EncDecAttention)
+ self.ff_layer = decoder_layer.layer[2]

  def pre_self_attn_layer_norm(self, hidden_states):
  return self.self_attn_layer_norm(hidden_states)
@@ -167,13 +166,13 @@ class T5Block(Seq2SeqDecoderLayer):


  class T5LayerSelfAttention(Seq2SeqSelfAttention):
- def __post_init__(self):
- self.q_proj = self._original_mod.q
- self.k_proj = self._original_mod.k
- self.v_proj = self._original_mod.v
- self.out_proj = self._original_mod.o
- self.num_heads = self._original_mod.n_heads
- self.head_dim = self._original_mod.key_value_proj_dim
+ def __post_init__(self, attn: nn.Module):
+ self.q_proj = attn.q
+ self.k_proj = attn.k
+ self.v_proj = attn.v
+ self.out_proj = attn.o
+ self.num_heads = attn.n_heads
+ self.head_dim = attn.key_value_proj_dim
  self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode

  def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -153,7 +153,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
  return redirect(val)

  @classmethod
- def wrap_model_if_needed(
+ def _wrap_model_if_needed(
  self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
  ):
  return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
@@ -161,7 +161,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
  @classmethod
  @torch.inference_mode()
  def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+ wrapped_model = cls._wrap_model_if_needed(model, rbln_config)

  enc_compile_config = rbln_config.compile_cfgs[0]
  dec_compile_config = rbln_config.compile_cfgs[1]
@@ -184,14 +184,6 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
  if "key_value_states" in name:
  context.mark_static_address(tensor)

- compiled_decoder = cls.compile(
- wrapped_model.decoder,
- dec_compile_config,
- create_runtimes=rbln_config.create_runtimes,
- device=rbln_config.device,
- example_inputs=dec_example_inputs,
- compile_context=context,
- )
  compiled_encoder = cls.compile(
  wrapped_model.encoder,
  enc_compile_config,
@@ -201,6 +193,15 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
  compile_context=context,
  )

+ compiled_decoder = cls.compile(
+ wrapped_model.decoder,
+ dec_compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
+ example_inputs=dec_example_inputs,
+ compile_context=context,
+ )
+
  return {"encoder": compiled_encoder, "decoder": compiled_decoder}

  @classmethod
@@ -353,6 +354,20 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
  static_real_features: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> SampleTSPredictionOutput:
+ """
+ Generate pass for the RBLN-optimized Time Series Transformer model for time series forecasting.
+
+ Args:
+ past_values (torch.FloatTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size)): Past values of the time series, that serve as context in order to predict the future.
+ past_time_features (torch.FloatTensor of shape (batch_size, sequence_length, num_features)): Required time features, which the model internally will add to past_values.
+ future_time_features (torch.FloatTensor of shape (batch_size, prediction_length, num_features)): Required time features for the prediction window, which the model internally will add to future_values.
+ past_observed_mask (torch.BoolTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size), optional): Boolean mask to indicate which past_values were observed and which were missing.
+ static_categorical_features (torch.LongTensor of shape (batch_size, number of static categorical features), optional): Optional static categorical features for which the model will learn an embedding, which it will add to the values of the time series.
+ static_real_features (torch.FloatTensor of shape (batch_size, number of static real features), optional): Optional static real features which the model will add to the values of the time series.
+
+ Returns:
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SampleTSPredictionOutput object.
+ """
  self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})

  outputs = self.encoder(

optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py
@@ -140,7 +140,6 @@ class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
  class TimeSeriesTransformersDecoder(nn.Module):
  def __init__(self, model, layers, **kwargs):
  super().__init__()
- self._original_mod = model
  self.config = model.config
  self.layers = nn.ModuleList(layers)
  self.value_embedding = model.value_embedding
@@ -190,7 +189,6 @@ class TimeSeriesTransformersDecoder(nn.Module):
  class TimeSeriesTransformersDecoderLayer(nn.Module):
  def __init__(self, decoder_layer, self_attn, cross_attn):
  super().__init__()
- self._original_mod = decoder_layer
  self.self_attn = self_attn
  self.encoder_attn = cross_attn
  self.embed_dim = decoder_layer.embed_dim
@@ -245,7 +243,6 @@ class TimeSeriesTransformersDecoderLayer(nn.Module):
  class TimeSeriesTransformersAttention(nn.Module):
  def __init__(self, attn, num_parallel_samples):
  super().__init__()
- self._original_mod = attn
  self.q_proj = attn.q_proj
  self.k_proj = attn.k_proj
  self.v_proj = attn.v_proj

optimum/rbln/transformers/models/vit/modeling_vit.py
@@ -12,6 +12,11 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from typing import Tuple, Union
+
+ import torch
+ from transformers.modeling_outputs import ImageClassifierOutput
+
  from ...modeling_generic import RBLNModelForImageClassification


@@ -23,3 +28,17 @@ class RBLNViTForImageClassification(RBLNModelForImageClassification):
  on RBLN devices, supporting image classification with transformer-based architectures
  that process images as sequences of patches.
  """
+
+ def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[ImageClassifierOutput, Tuple]:
+ """
+ Forward pass for the RBLN-optimized Vision Transformer model for image classification.
+
+ Args:
+ pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)):
+ The tensors corresponding to the input images.
+
+ Returns:
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutput object.
+
+ """
+ return super().forward(pixel_values, **kwargs)

optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py
@@ -12,10 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ...configuration_generic import RBLNModelForMaskedLMConfig
+ from typing import Any, Optional

+ from ....configuration_utils import RBLNModelConfig

- class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
+
+ class RBLNWav2Vec2ForCTCConfig(RBLNModelConfig):
  """
  Configuration class for RBLNWav2Vec2ForCTC.

@@ -23,4 +25,14 @@ class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
  RBLN-optimized Wav2Vec2 models for Connectionist Temporal Classification (CTC) tasks.
  """

- rbln_model_input_names = ["input_values"]
+ def __init__(
+ self,
+ max_seq_len: Optional[int] = None,
+ batch_size: Optional[int] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(**kwargs)
+ self.max_seq_len = max_seq_len
+ self.batch_size = batch_size or 1
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
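
The explicit constructor above replaces the old class-level rbln_model_input_names hook; a short sketch of building the new config directly (values are illustrative, and the root-level re-export is assumed to match the other RBLN config classes):

    from optimum.rbln import RBLNWav2Vec2ForCTCConfig

    cfg = RBLNWav2Vec2ForCTCConfig(max_seq_len=160_000, batch_size=1)
    print(cfg.batch_size, cfg.max_seq_len)   # 1 160000
    # batch_size defaults to 1 when omitted; a negative batch_size raises ValueError.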

optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -13,13 +13,21 @@
  # limitations under the License.


+ from typing import TYPE_CHECKING, Optional, Union
+
  import torch
- from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
+ from transformers import AutoModelForCTC, Wav2Vec2Config, Wav2Vec2ForCTC
+ from transformers.modeling_outputs import CausalLMOutput

- from ...modeling_generic import RBLNModelForMaskedLM
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
  from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig


+ if TYPE_CHECKING:
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
  class _Wav2Vec2(torch.nn.Module):
  def __init__(self, model: "Wav2Vec2ForCTC"):
  super().__init__()
@@ -30,13 +38,10 @@ class _Wav2Vec2(torch.nn.Module):
  return self.model.lm_head(output[0])


- class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
+ class RBLNWav2Vec2ForCTC(RBLNModel):
  """
  Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).

- This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
- library implements for all its model.
-
  It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:

  - transferring the checkpoint weights of the original into an optimized RBLN graph,
@@ -44,9 +49,56 @@ class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
  """

  main_input_name = "input_values"
- auto_model_class = AutoModelForMaskedLM
+ auto_model_class = AutoModelForCTC
  rbln_dtype = "float32"

  @classmethod
- def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
+ def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
  return _Wav2Vec2(model).eval()
+
+ @classmethod
+ def _update_rbln_config(
+ cls,
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+ model: Optional["PreTrainedModel"] = None,
+ model_config: "Wav2Vec2Config" = None,
+ rbln_config: Optional[RBLNWav2Vec2ForCTCConfig] = None,
+ ) -> RBLNWav2Vec2ForCTCConfig:
+ if rbln_config.max_seq_len is None:
+ for tokenizer in preprocessors:
+ if hasattr(tokenizer, "model_max_length"):
+ rbln_config.max_seq_len = tokenizer.model_max_length
+ break
+ if rbln_config.max_seq_len is None:
+ raise ValueError("`rbln_max_seq_len` should be specified!")
+
+ rbln_compile_config = RBLNCompileConfig(
+ input_info=[
+ (
+ "input_values",
+ [
+ rbln_config.batch_size,
+ rbln_config.max_seq_len,
+ ],
+ "float32",
+ )
+ ]
+ )
+
+ rbln_config.set_compile_cfgs([rbln_compile_config])
+ return rbln_config
+
+ def forward(
+ self, input_values: torch.Tensor, return_dict: Optional[bool] = None, **kwargs
+ ) -> Union[CausalLMOutput, tuple]:
+ """
+ Forward pass for the RBLN-optimized Wav2Vec2 model for Connectionist Temporal Classification (CTC).
+
+ Args:
+ input_values (torch.FloatTensor of shape (batch_size, sequence_length)): Float values of input raw speech waveform. Values can be obtained by loading a .flac or .wav audio file into an array of type List[float] or a numpy.ndarray, e.g. via the soundfile library (pip install soundfile). To prepare the array into input_values, the AutoProcessor should be used for padding and conversion into a tensor of type torch.FloatTensor.
+ return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+ Returns:
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CausalLMOutput object.
+ """
+ return super().forward(input_values=input_values, return_dict=return_dict, **kwargs)
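
Per _update_rbln_config above, max_seq_len falls back to the preprocessor's model_max_length when it is not set, and compilation fails otherwise. A hedged export sketch; the checkpoint id and rbln_* keyword arguments follow the usual optimum-rbln from_pretrained export flow and are not prescribed by this diff:

    from optimum.rbln import RBLNWav2Vec2ForCTC

    model = RBLNWav2Vec2ForCTC.from_pretrained(
        "facebook/wav2vec2-base-960h",   # illustrative checkpoint
        export=True,
        rbln_max_seq_len=160_000,        # omit to inherit the processor's model_max_length, if available
        rbln_batch_size=1,
    )
    # forward now returns a CausalLMOutput (or a tuple when return_dict=False),
    # since the class is registered under AutoModelForCTC rather than AutoModelForMaskedLM.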

optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -31,29 +31,63 @@ Generation utilities for Whisper.
  Modified from `transformers.models.whisper.generation_whisper.py`
  """

+ from typing import Any, Dict, Optional, Union
+
  import torch
  import transformers
  from packaging import version
  from transformers import GenerationMixin
+ from transformers.generation.configuration_utils import GenerationConfig
+ from transformers.modeling_outputs import ModelOutput
  from transformers.models.whisper.generation_whisper import WhisperGenerationMixin


  class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
- def generate(self, *args, generation_config=None, **kwargs):
- num_beams = kwargs.get(
- "num_beams",
- generation_config.num_beams
- if hasattr(generation_config, "num_beams") and generation_config.num_beams is not None
- else 1,
- )
- if num_beams > 1:
- raise ValueError(
- f"Beam search is not supported in RBLNWhisperGenerationMixin. "
- f"Received num_beams={num_beams}, but only num_beams=1 is allowed. "
- f"Please set num_beams=1 for greedy search or adjust your configuration."
- )
+ def generate(
+ self,
+ input_features: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ return_segments: Optional[bool] = None,
+ return_timestamps: Optional[bool] = None,
+ return_token_timestamps: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
+ """
+ The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+ Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
+
+ Args:
+ input_features(torch.Tensor, optional): The input features to the model.
+ attention_mask(torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
+ generation_config(GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+ If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+ Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+ return_segments(bool, optional): Whether to return segments.
+ return_timestamps(bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
+ return_token_timestamps(bool, optional): Whether to return token timestamps.
+ kwargs(dict[str, Any], optional): Additional arguments passed to the generate function.
+
+ Returns:
+ Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
+ """
+ if kwargs.get("num_beams", None) is not None:
+ if kwargs.get("num_beams") != 1:
+ raise ValueError(
+ "Beam search is not supported in RBLNWhisperGenerationMixin. "
+ "Received num_beams={num_beams}, but only num_beams=1 is allowed. "
+ "Please set num_beams=1 for greedy search or adjust your configuration."
+ )

- return super().generate(*args, **kwargs)
+ return super().generate(
+ input_features,
+ attention_mask=attention_mask,
+ generation_config=generation_config,
+ return_segments=return_segments,
+ return_timestamps=return_timestamps,
+ return_token_timestamps=return_token_timestamps,
+ **kwargs,
+ )

  def _postprocess_outputs(
  self,
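
With the expanded signature, the Whisper-specific generation arguments are passed explicitly and beam search remains rejected. A minimal call sketch; the checkpoint id, export keyword, and random features are illustrative assumptions, not prescribed by this diff:

    import torch
    from optimum.rbln import RBLNWhisperForConditionalGeneration

    model = RBLNWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny",   # illustrative checkpoint
        export=True,
    )
    input_features = torch.randn(1, 80, 3000)   # stand-in for log-mel features from the processor
    tokens = model.generate(input_features, return_timestamps=True)
    # Passing num_beams greater than 1 raises ValueError; only greedy search is supported.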

optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -203,7 +203,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
  raise NotImplementedError

  @classmethod
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
+ def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
  return WhisperWrapper(
  model,
  use_attention_mask=rbln_config.use_attention_mask,
@@ -213,7 +213,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
  @classmethod
  @torch.inference_mode()
  def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+ wrapped_model = cls._wrap_model_if_needed(model, rbln_config)

  enc_compile_config = rbln_config.compile_cfgs[0]
  dec_compile_config = rbln_config.compile_cfgs[1]

optimum/rbln/transformers/models/whisper/whisper_architecture.py
@@ -154,7 +154,6 @@ class WhisperDecoderWrapper(torch.nn.Module):
  class WhisperDecoder(nn.Module):
  def __init__(self, model, layers, **kwargs):
  super().__init__()
- self._original_mod = model
  self.layers = nn.ModuleList(layers)
  self.embed_tokens = model.embed_tokens
  self.layer_norm = model.layer_norm
@@ -210,7 +209,6 @@ class WhisperDecoder(nn.Module):
  class WhisperDecoderLayer(nn.Module):
  def __init__(self, decoder_layer, self_attn, cross_attn):
  super().__init__()
- self._original_mod = decoder_layer
  self.self_attn = self_attn
  self.encoder_attn = cross_attn
  self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
@@ -263,7 +261,6 @@ class WhisperDecoderLayer(nn.Module):
  class WhisperAttention(nn.Module):
  def __init__(self, attn):
  super().__init__()
- self._original_mod = attn
  self.q_proj = attn.q_proj
  self.k_proj = attn.k_proj
  self.v_proj = attn.v_proj