optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/qwen3/qwen3_architecture.py
@@ -12,16 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel
 
-from ..decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyWrapper,
-    RotaryEmbedding,
-)
+from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyWrapper
 
 
 class Qwen3Wrapper(DecoderOnlyWrapper):
@@ -37,239 +29,3 @@ class Qwen3Attention(DecoderOnlyAttention):
         self.o_proj = self._original_mod.o_proj
         self.q_norm = self._original_mod.q_norm
         self.k_norm = self._original_mod.k_norm
-
-
-class Qwen3ModelWrapper(nn.Module):
-    def __init__(
-        self,
-        model,
-        attn_impl=None,
-        use_inputs_embeds=None,
-        use_attention_mask=None,
-        use_rotary_emb=None,
-        cache_impl=None,
-        kvcache_partition_len=None,
-        max_seq_len=None,
-        kvcache_block_size=None,
-        sliding_window=None,
-        sliding_window_layers=None,
-    ):
-        super().__init__()
-        self.config = model.config
-
-        if use_rotary_emb:
-            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
-            if isinstance(rotary_embs, tuple):
-                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
-            else:
-                self.rotary_emb = rotary_embs
-        else:
-            self.rotary_emb = None
-
-        self._original_mod = model
-        self.use_inputs_embeds = use_inputs_embeds
-        self.attn_impl = attn_impl
-        self.cache_impl = cache_impl
-        self.use_attention_mask = use_attention_mask
-        self.kvcache_partition_len = kvcache_partition_len
-        self.kvcache_block_size = kvcache_block_size
-        self.max_seq_len = max_seq_len
-        self.sliding_window = sliding_window
-        self.sliding_window_layers = sliding_window_layers
-        self.model = self.convert_to_rbln_model(model)
-
-    def get_rotary_emb(self, max_seq_len):
-        return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-    def convert_to_rbln_model(self, base_model: PreTrainedModel):
-        for layer_idx, layer in enumerate(base_model.layers):
-            is_sliding = layer_idx in self.sliding_window_layers
-            new_self_attn = Qwen3Attention(
-                layer.self_attn,
-                self.use_attention_mask if not is_sliding else True,
-                use_position_ids=None,
-                kvcache_block_size=self.sliding_window
-                if layer_idx in self.sliding_window_layers
-                else self.kvcache_block_size,
-                is_sliding=is_sliding,
-                attn_impl=self.attn_impl if not is_sliding else "eager",
-                kvcache_partition_len=self.kvcache_partition_len,
-            )
-            base_model.layers[layer_idx] = DecoderOnlyLayer(layer, new_self_attn)
-
-        return base_model
-
-    @property
-    def hidden_multiplier(self):
-        return 1
-
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.norm
-
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.embed_tokens
-
-    def get_pos_embedding(self) -> nn.Embedding:
-        raise NotImplementedError(
-            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
-        )
-
-    def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
-        if self.attn_impl not in ["flash_attn"]:
-            raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
-        partition_len = self.kvcache_partition_len
-        num_partition = max_seq_len // partition_len
-
-        cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
-        pidx = torch.arange(num_partition)
-        cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
-        return cache_pos_for_partitions
-
-    def get_local_cache_positions(self, position_ids, query_position):
-        max_cache_len = self.model.config.sliding_window
-        valid_input_len = 1 if query_position is None else query_position + 1
-        cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
-        cache_offset = (
-            torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
-        )  # cache offset for next steps
-
-        return cache_seq_len, cache_offset
-
-    def prepare_forward_args(self, *args):
-        args = list(args)
-        input_ids = None if self.use_inputs_embeds else args.pop(0)
-        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
-        cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
-        local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
-        query_position = args.pop(0) if self.sliding_window else None
-        attention_mask = args.pop(0) if self.use_attention_mask else None
-        position_ids = None
-        past_key_values = args
-
-        if len(past_key_values) != 2 * self.config.num_hidden_layers:
-            raise ValueError(
-                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.config.num_hidden_layers}"
-            )
-
-        # [key, value] * n_layer -> ( (key, value) ) * n_layer
-        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
-        _past_key_values = []
-        for i in range(self.config.num_hidden_layers):
-            key_states = past_key_values[i * 2]
-            value_states = past_key_values[i * 2 + 1]
-            past_key_value = [key_states, value_states]
-            _past_key_values.append(past_key_value)
-        past_key_values = _past_key_values
-
-        if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
-            rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
-        else:
-            rotary_emb = self.rotary_emb
-
-        return (
-            input_ids,
-            inputs_embeds,
-            cache_position,
-            global_block_tables,
-            local_block_tables,
-            attention_mask,
-            position_ids,
-            query_position,
-            past_key_values,
-            rotary_emb,
-        )
-
-    def forward(self, *args):
-        (
-            input_ids,
-            inputs_embeds,
-            cache_position,
-            global_block_tables,
-            local_block_tables,
-            attention_mask,
-            position_ids,
-            query_position,
-            past_key_values,
-            rotary_emb,
-        ) = self.prepare_forward_args(*args)
-
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        if inputs_embeds is None:
-            inputs_embeds = self.get_embedding()(input_ids)
-
-        hidden_states = inputs_embeds * self.hidden_multiplier
-
-        # get cos,sin vector if needed
-        position_ids = position_ids if position_ids is not None else cache_position
-        if rotary_emb is not None:
-            if isinstance(rotary_emb, torch.Tensor):
-                cos = rotary_emb[0]
-                sin = rotary_emb[1]
-            else:
-                cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
-                cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-        else:
-            batch_size = inputs_embeds.shape[0]
-            if position_ids.shape[0] > 1:
-                position_embeds = []
-                for b_idx in range(batch_size):
-                    position_embed = self.get_pos_embedding()(position_ids[b_idx])
-                    position_embeds.append(position_embed)
-
-                position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-            else:
-                position_embeds = self.get_pos_embedding()(position_ids)
-            hidden_states = hidden_states + position_embeds
-            cos, sin = None, None
-
-        # Get sequence positions for flash attention
-        if self.attn_impl == "flash_attn":
-            seq_positions = cache_position[:, 0]
-            seq_positions = self.convert_sequence_positions_for_flash_attn(
-                seq_positions=seq_positions, max_seq_len=self.max_seq_len
-            )
-        else:
-            seq_positions = cache_position[:, :1]
-
-        # Get local cache positions for sliding window layers
-        if len(self.sliding_window_layers) > 0:
-            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
-
-        for layer_idx, layer in enumerate(self.model.layers):
-            is_sliding = True if layer_idx in self.sliding_window_layers else False
-            hidden_states = layer(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                seq_positions=sliding_cache_pos if is_sliding else seq_positions,
-                past_key_values=past_key_values,
-                cos=cos,
-                sin=sin,
-                block_tables=local_block_tables if is_sliding else global_block_tables,
-            )
-
-        hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
-
-
-def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-    """Slice cos[cache_position], sin[cache_position] vector for the query."""
-    if cache_position.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(cache_position.shape[0]):
-            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-    return cos, sin
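
With this change Qwen3 no longer carries its own model wrapper: the removed Qwen3ModelWrapper and slice_and_unsqueeze_cos_sin helper are superseded by the shared DecoderOnlyWrapper path. Below is a minimal, hedged usage sketch of compiling a Qwen3 checkpoint through that shared path; the checkpoint id and the rbln_* keyword arguments are illustrative assumptions, not taken from this diff.

```python
# Illustrative sketch only: compile and run a Qwen3 checkpoint via the shared
# decoder-only path. Checkpoint id and rbln_* kwargs are assumptions.
from transformers import AutoTokenizer

from optimum.rbln import RBLNQwen3ForCausalLM

model_id = "Qwen/Qwen3-0.6B"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = RBLNQwen3ForCausalLM.from_pretrained(
    model_id,
    export=True,            # compile the PyTorch weights into an RBLN artifact
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)

inputs = tokenizer("Hello, RBLN!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```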

optimum/rbln/transformers/models/resnet/configuration_resnet.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 from ...configuration_generic import RBLNModelForImageClassificationConfig
 
 
@@ -23,3 +25,18 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConf
     This configuration class stores the configuration parameters specific to
     RBLN-optimized ResNet models for image classification tasks.
     """
+
+    def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
+        """
+        Args:
+            image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
+                Can be an integer for square images or a tuple (height, width).
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.output_hidden_states = output_hidden_states
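
The new output_hidden_states field lives on the RBLN config itself, so the choice is fixed at compile time. A small sketch of constructing the config follows; that the class is exported from the package top level, like the other RBLN config classes, is an assumption.

```python
# Sketch: the hidden-states choice now lives on the RBLN config (compile-time).
from optimum.rbln import RBLNResNetForImageClassificationConfig

config = RBLNResNetForImageClassificationConfig(
    batch_size=1,               # forwarded to the parent RBLNModelConfig
    output_hidden_states=True,  # None falls back to the HF model config (see _update_rbln_config below)
)
```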

optimum/rbln/transformers/models/resnet/modeling_resnet.py
@@ -13,7 +13,17 @@
 # limitations under the License.
 
 
+from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
+
 from ...modeling_generic import RBLNModelForImageClassification
+from .configuration_resnet import RBLNResNetForImageClassificationConfig
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
@@ -24,3 +34,66 @@ class RBLNResNetForImageClassification(RBLNModelForImageClassification):
     on RBLN devices, supporting image classification with convolutional neural networks
     designed for computer vision tasks.
     """
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional["RBLNResNetForImageClassificationConfig"] = None,
+    ) -> "RBLNResNetForImageClassificationConfig":
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
+        rbln_config = super()._update_rbln_config(
+            preprocessors=preprocessors,
+            model=model,
+            model_config=model_config,
+            rbln_config=rbln_config,
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: "RBLNResNetForImageClassificationConfig"
+    ) -> torch.nn.Module:
+        class _ResNetForImageClassification(torch.nn.Module):
+            def __init__(self, model: torch.nn.Module, output_hidden_states: bool):
+                super().__init__()
+                self.model = model
+                self.output_hidden_states = output_hidden_states
+
+            def forward(self, *args, **kwargs):
+                output = self.model(*args, output_hidden_states=self.output_hidden_states, **kwargs)
+                return output
+
+        return _ResNetForImageClassification(model, rbln_config.output_hidden_states)
+
+    def forward(
+        self, pixel_values: torch.Tensor, output_hidden_states: bool = None, return_dict: bool = None, **kwargs
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
+        """
+        Forward pass for the RBLN-optimized ResNet model for image classification.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
+            output_hidden_states (bool, *optional*, defaults to False): Whether or not to return the hidden states of all layers.
+                See hidden_states under returned tensors for more details.
+            return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutputWithNoAttention object.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        return super().forward(pixel_values=pixel_values, return_dict=return_dict, **kwargs)
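
Because output_hidden_states is baked into the compiled graph, forward() only accepts the value the model was compiled with (the ValueError above). A hedged end-to-end sketch follows; the checkpoint id and the rbln_*-prefixed keyword convention for passing config fields to from_pretrained are assumptions.

```python
# Hedged sketch: compile ResNet with hidden states enabled, then run inference.
import torch

from optimum.rbln import RBLNResNetForImageClassification

model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",           # illustrative checkpoint
    export=True,
    rbln_output_hidden_states=True,  # assumed rbln_-prefixed config kwarg
)

pixel_values = torch.randn(1, 3, 224, 224)
outputs = model(pixel_values)  # ImageClassifierOutputWithNoAttention
print(outputs.logits.shape)
# Passing a value different from the one compiled in raises the ValueError above:
# model(pixel_values, output_hidden_states=False)
```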

optimum/rbln/transformers/models/roberta/modeling_roberta.py
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
 from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
 
 
@@ -26,6 +31,19 @@ class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Union[Tuple, MaskedLMOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
+
 
 class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
@@ -37,3 +55,18 @@ class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        """
+        Forward pass for the RBLN-optimized RoBERTa model for sequence classification tasks.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
+        """
+        return super().forward(input_ids, attention_mask, **kwargs)
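
The added forward() overrides only document the call signature; the behavior is still delegated to the generic RBLN classes. A hedged masked-LM sketch follows; the checkpoint id, the rbln_max_seq_len kwarg, and the pad-to-compiled-length convention are assumptions.

```python
# Hedged sketch of RBLNRobertaForMaskedLM inference.
from transformers import AutoTokenizer

from optimum.rbln import RBLNRobertaForMaskedLM

model_id = "FacebookAI/roberta-base"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNRobertaForMaskedLM.from_pretrained(model_id, export=True, rbln_max_seq_len=512)

# RBLN graphs are static, so pad inputs to the compiled sequence length.
inputs = tokenizer(
    "The capital of France is <mask>.",
    return_tensors="pt",
    padding="max_length",
    max_length=512,
)
outputs = model(inputs.input_ids, inputs.attention_mask)  # MaskedLMOutput

mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(outputs.logits[0, mask_pos].argmax(-1)))
```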

optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py
@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
-
-import rebel
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.deprecation import deprecate_kwarg
 from ....utils.logging import get_logger
 
 
@@ -24,14 +23,18 @@ logger = get_logger()
 
 
 class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+    support_paged_attention = None
+
+    @deprecate_kwarg(old_name="pad_token_id", version="0.10.0")
     def __init__(
         self,
         batch_size: Optional[int] = None,
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         use_attention_mask: Optional[bool] = None,
-        pad_token_id: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        kvcache_num_blocks: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -39,9 +42,11 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
            use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
-                This is automatically set to True for RBLN-CA02 devices.
-            pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+                PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+            kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
        Raises:
            ValueError: If batch_size is not a positive integer.
@@ -55,12 +60,12 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
        self.dec_max_seq_len = dec_max_seq_len
 
        self.use_attention_mask = use_attention_mask
-        npu = self.npu or rebel.get_npu_name()
-        if npu == "RBLN-CA02":
-            if self.use_attention_mask is False:
-                logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-            self.use_attention_mask = True
-        else:
-            self.use_attention_mask = self.use_attention_mask or False
 
-        self.pad_token_id = pad_token_id
+        if self.support_paged_attention:
+            self.kvcache_num_blocks = kvcache_num_blocks
+            self.kvcache_block_size = kvcache_block_size
+        else:
+            if kvcache_num_blocks is not None or kvcache_block_size is not None:
+                raise ValueError(
+                    "You cannot set kvcache_num_blocks or kvcache_block_size as paged attention is not supported for the model."
+                )
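
In practice the constructor change means pad_token_id moves off the RBLN config (it is now read from the HuggingFace config at padding time, see the modeling hunk below) and two paged-attention knobs appear. A short sketch follows; RBLNT5ForConditionalGenerationConfig is used purely as an illustration, and whether that concrete class enables paged attention is an assumption, not verified here.

```python
# Sketch of the new Seq2Seq config surface (values illustrative).
from optimum.rbln import RBLNT5ForConditionalGenerationConfig

config = RBLNT5ForConditionalGenerationConfig(
    batch_size=1,
    enc_max_seq_len=512,
    dec_max_seq_len=512,
    # kvcache_num_blocks / kvcache_block_size are only accepted when the concrete
    # config class enables support_paged_attention; otherwise the constructor raises
    # the ValueError shown above. Left as None they default to batch_size and
    # dec_max_seq_len respectively (see _update_paged_attention_config below).
)

# pad_token_id is deprecated via @deprecate_kwarg and scheduled for removal in 0.10.0.
```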

optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -20,7 +20,9 @@ import rebel
 import torch
 from rebel.compile_context import CompileContext
 from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.generation.utils import GenerationMixin
+from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
@@ -32,13 +34,13 @@ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
 logger = get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, GenerationConfig, PretrainedConfig
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+    def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
         output = super().forward(*args, **kwargs)
         return BaseModelOutput(last_hidden_state=output)
 
@@ -83,7 +85,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
             decoding_step = cache_position[b_idx].item()
             if not (0 <= decoding_step < self.dec_max_seq_len):
                 raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    f"Decoding step {decoding_step} out of bounds for decoder_max_seq_len ({self.dec_max_seq_len})."
                 )
             decoder_attention_mask[b_idx, : decoding_step + 1] = 1
 
@@ -101,7 +103,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=lm_logits)
 
 
-class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
+class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -117,6 +119,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
     support_causal_attn = None
+    _is_stateful = False
 
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.batch_size
@@ -138,7 +141,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
@@ -181,6 +184,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
+    @classmethod
+    def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
+        rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+        rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+        if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+            raise NotImplementedError(
+                f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+            )
+
+        if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+            raise NotImplementedError(
+                f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+            )
+
     @classmethod
     def _update_rbln_config(
         cls,
@@ -204,12 +222,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             model_config, "max_position_embeddings", None
        )
 
-        pad_token_id = getattr(model_config, "pad_token_id", None)
-        pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
-        pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
-        pad_token_id = pad_token_id or -1
-        rbln_config.pad_token_id = pad_token_id
-
        if rbln_config.enc_max_seq_len is None:
            enc_max_seq_len = max_position_embeddings
            for tokenizer in preprocessors:
@@ -238,6 +250,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
        if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
            raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
 
+        if rbln_config.support_paged_attention:
+            cls._update_paged_attention_config(model_config, rbln_config)
+
        # model input info
        enc_input_info = [
            ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
@@ -310,6 +325,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
+
        return rbln_config
 
     @classmethod
@@ -411,7 +427,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
            inputs_tensor = torch.nn.functional.pad(
                inputs_tensor,
                (0, self.rbln_config.enc_max_seq_len - input_len),
-                value=self.rbln_config.pad_token_id,
+                value=self.config.pad_token_id,
            )
            model_kwargs["attention_mask"] = torch.nn.functional.pad(
                model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
@@ -430,3 +446,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
 
        return model_kwargs
+
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        **kwargs,
+    ) -> Union[ModelOutput, torch.LongTensor]:
+        """
+        The generate function is used in its standard form, as in the HuggingFace transformers library. Users can call this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
+
+        Args:
+            input_ids (torch.LongTensor): The input ids to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+        Returns:
+            The generated sequences of token ids for models with a language modeling head.
+        """
+        if generation_config is not None:
+            kwargs["generation_config"] = generation_config
+        if attention_mask is not None:
+            kwargs["attention_mask"] = attention_mask
+
+        return super().generate(input_ids, **kwargs)
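
RBLNModelForSeq2SeqLM now mixes in GenerationMixin directly and documents generate() explicitly. A hedged sketch with BART summarization follows; the checkpoint id and the rbln_* kwargs are assumptions, not taken from this diff.

```python
# Hedged sketch of generate() on an RBLN Seq2Seq model.
from transformers import AutoTokenizer

from optimum.rbln import RBLNBartForConditionalGeneration

model_id = "facebook/bart-large-cnn"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNBartForConditionalGeneration.from_pretrained(model_id, export=True, rbln_batch_size=1)

inputs = tokenizer("A long article to summarize ...", return_tensors="pt")
# attention_mask and generation_config are forwarded to GenerationMixin.generate();
# encoder inputs are padded internally to enc_max_seq_len using config.pad_token_id.
summary_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=64)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```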

optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -31,7 +31,7 @@ class Seq2SeqWrapper:
     Args:
         model (nn.Module): The Seq2Seq model to wrap.
         enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
-        **kwargs: Additional arguments to pass to the decoder wrapper.
+        kwargs: Additional arguments to pass to the decoder wrapper.
     """
 
     def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
@@ -125,7 +125,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
 
     Args:
         model (nn.Module): The Seq2Seq model containing the decoder.
-        **kwargs: Additional arguments for decoder configuration.
+        kwargs: Additional arguments for decoder configuration.
     """
 
     def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):

optimum/rbln/transformers/models/siglip/__init__.py
@@ -12,9 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_siglip import (
-    RBLNSiglipVisionModelConfig,
-)
-from .modeling_siglip import (
-    RBLNSiglipVisionModel,
-)
+from .configuration_siglip import RBLNSiglipVisionModelConfig
+from .modeling_siglip import RBLNSiglipVisionModel