optimum-rbln 0.7.4a0__py3-none-any.whl → 0.7.4a1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
optimum/rbln/transformers/models/whisper/whisper_architecture.py

@@ -16,27 +16,29 @@ from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-)
 from transformers.modeling_outputs import (
     BaseModelOutput,
     Seq2SeqLMOutput,
 )
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_cache_update, register_rbln_custom_paged_add_softmax_attention
+from ....ops import (
+    register_rbln_custom_cache_update,
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+)
 
 
 logger = logging.get_logger(__name__)
 
 
 class WhisperWrapper:
-    def __init__(self, model, rbln_token_timestamps):
+    def __init__(self, model, use_attention_mask, rbln_token_timestamps):
         register_rbln_custom_cache_update()
-        register_rbln_custom_paged_add_softmax_attention()
         self.encoder = WhisperEncoderWrapper(model)
-        self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
+        self.decoder = WhisperDecoderWrapper(
+            model, use_attention_mask=use_attention_mask, output_attentions=rbln_token_timestamps
+        )
 
 
 class WhisperEncoderWrapper(torch.nn.Module):
@@ -57,6 +59,7 @@ class WhisperEncoderWrapper(torch.nn.Module):
     def forward(
         self,
         input_features: Optional[torch.LongTensor],
+        b_idx: torch.Tensor,
        cross_key_values: torch.Tensor,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         # 1. get encoder last_hidden_states
@@ -76,21 +79,34 @@ class WhisperEncoderWrapper(torch.nn.Module):
         cross_kv = torch.stack(cross_kv, dim=0)
 
         # 3. update cross_attention's past_key_value to the device-dram for optimization.
-        bidx = torch.tensor(0, dtype=torch.int16)
-        axis = torch.tensor(1, dtype=torch.int16)
-        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+        batch_axis = torch.tensor(1, dtype=torch.int16)
+        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
 
         return enc_output
 
 
 class WhisperDecoderWrapper(torch.nn.Module):
-    def __init__(self, model, output_attentions: bool = False):
+    def __init__(self, model, use_attention_mask: bool = True, output_attentions: bool = False, **kwargs):
         super().__init__()
         self.config = model.config
-        self.num_layers = self.config.decoder_layers
         self.proj_out = model.proj_out
-        self.decoder = self.convert_to_rbln_conditional_generation(model)
+        self.use_attention_mask = use_attention_mask
         self.output_attentions = output_attentions
+        self.__post_init__(model, **kwargs)
+
+    def __post_init__(self, model: nn.Module, **kwargs):
+        """
+        Post-initialization to extract and configure encoder-related attributes.
+        It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+        by subclasses to modify or add custom attributes as necessary.
+        """
+        if self.use_attention_mask:
+            register_rbln_custom_paged_attention()
+        else:
+            register_rbln_custom_paged_causal_attention()
+
+        self.num_layers = self.config.decoder_layers
+        self.decoder = self.convert_to_rbln_conditional_generation(model)
 
     def convert_to_rbln_conditional_generation(self, model: nn.Module):
         new_layers = []
@@ -105,13 +121,21 @@ class WhisperDecoderWrapper(torch.nn.Module):
 
     def forward(
         self,
-        decoder_input_ids: torch.Tensor,
-        decoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        block_tables: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-        *self_kv_cache: torch.Tensor,
+        *args,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        if self.use_attention_mask:
+            (
+                decoder_input_ids,
+                decoder_attention_mask,
+                cache_position,
+                block_tables,
+                cross_kv_cache,
+                *self_kv_cache,
+            ) = args
+        else:
+            decoder_attention_mask = None
+            (decoder_input_ids, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+
         # prepare past_key_values
         self_past_key_values = ()
         cross_past_key_values = ()
@@ -163,11 +187,18 @@ class WhisperDecoder(nn.Module):
 
         # positional embeding
         inputs_embeds = self.embed_tokens(input_ids)
-        positions = self.embed_positions(input_ids, position_ids=cache_position)
-        hidden_states = inputs_embeds + positions
+        all_hiddens = []
+        for i in range(inputs_embeds.shape[0]):
+            position_id = cache_position[i]
+            position = self.embed_positions(input_ids, position_ids=position_id)
+            batch_hidden = position + inputs_embeds[i]
+            all_hiddens.append(batch_hidden)
 
-        # prepare casual_attn_mask
-        attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
+        hidden_states = torch.stack(all_hiddens, dim=0)
+
+        # prepare attn mask (normal attention - masked)
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :]
 
         cross_attentions = ()
         # iterate decoder_layer
@@ -279,18 +310,22 @@ class WhisperSelfAttention(WhisperAttention):
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
         block_size = past_key_value[0].shape[-2]
 
-        attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
+        args = [
             query_states,
             key_states,
             value_states,
-            attention_mask.unsqueeze(2),
             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
-            cache_position.expand(bsz, 1),
+            cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
             block_tables,
             block_size,
-        )
+        ]
+        if attention_mask is not None:
+            args.insert(3, attention_mask.unsqueeze(2))
+            attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(*args)
+        else:
+            attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(*args)
 
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
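
The functional core of the whisper_architecture.py changes above is that WhisperDecoderWrapper.forward now takes a flat *args tuple and only expects decoder_attention_mask when use_attention_mask is set, dispatching to paged_attn_decode or paged_causal_attn_decode accordingly. A minimal standalone sketch of that unpacking rule follows; the helper name and the dummy tensors are illustrative assumptions, not part of the package:

import torch

def unpack_decoder_args(use_attention_mask, *args):
    # Mirrors the branching added to WhisperDecoderWrapper.forward in this release:
    # the mask is only expected as the second positional input when use_attention_mask is set.
    if use_attention_mask:
        decoder_input_ids, decoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache = args
    else:
        decoder_attention_mask = None
        decoder_input_ids, cache_position, block_tables, cross_kv_cache, *self_kv_cache = args
    return decoder_input_ids, decoder_attention_mask, cache_position, block_tables, cross_kv_cache, self_kv_cache

# Dummy stand-ins for the compiled runtime's real inputs (shapes are arbitrary).
ids = torch.zeros(1, 1, dtype=torch.long)
mask = torch.ones(1, 4)
pos = torch.zeros(1, 1, dtype=torch.int32)
blocks = torch.zeros(1, 1, dtype=torch.int16)
cross_kv = torch.zeros(1)

assert unpack_decoder_args(True, ids, mask, pos, blocks, cross_kv)[1] is mask
assert unpack_decoder_args(False, ids, pos, blocks, cross_kv)[1] is None
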
optimum_rbln-0.7.4a0.dist-info/METADATA → optimum_rbln-0.7.4a1.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.4a0
+Version: 0.7.4a1
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
@@ -25,10 +25,10 @@ Requires-Python: <3.13,>=3.9
 Requires-Dist: accelerate>=1.0.1
 Requires-Dist: diffusers<=0.31.0
 Requires-Dist: packaging>=24.1
-Requires-Dist: torch<=2.5.1
-Requires-Dist: torchaudio<=2.5.1
-Requires-Dist: torchvision<=0.20.1
-Requires-Dist: transformers==4.48.3
+Requires-Dist: torch==2.6.0
+Requires-Dist: torchaudio<=2.6.0
+Requires-Dist: torchvision<=0.21.0
+Requires-Dist: transformers==4.50.3
 Description-Content-Type: text/markdown
 
 
optimum_rbln-0.7.4a0.dist-info/RECORD → optimum_rbln-0.7.4a1.dist-info/RECORD

@@ -1,5 +1,5 @@
-optimum/rbln/__init__.py,sha256=ZDzXcl-oAcYJhKjJMpotjbTih9awo7HzUb6T3MUEP6Q,6894
-optimum/rbln/__version__.py,sha256=xyj1Oj5eR1yz0oBU9FRdubMKrBiNrPrrW8h8ohd1iG8,513
+optimum/rbln/__init__.py,sha256=qW45z47BiNLTDtRFEhVEzr4THNFX0ygqCbdNKqI0biI,6992
+optimum/rbln/__version__.py,sha256=KifVR95YmJmHh5f74wGiEAzd-c6ElHQ3XFHbY8VRp14,513
 optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
 optimum/rbln/modeling_base.py,sha256=dNCL-BhrWCpuOVkZaj8-MW567Tf4lLo3p3Z3ldjWJfU,21779
 optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
@@ -41,28 +41,29 @@ optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py,sha256=9iIMZYvp
 optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py,sha256=OvB5bxX6HUiqJeIc3uukuEmUXYEx1pTqGNOtdG2l1m8,902
 optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
 optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
-optimum/rbln/ops/__init__.py,sha256=Wv2cJhEw8mqc6-To24bHzf4qQL8gM0Zh_2Ck77LB65g,947
+optimum/rbln/ops/__init__.py,sha256=LmTIX9yTfRiMDcalmb52yz5LhLRWqq3H5S94r0VDYDw,974
 optimum/rbln/ops/attn.py,sha256=OSgPoEgCwvR7HdjbnaVkFVMBcJ5RpRWcE6OCg2lVyGk,10634
 optimum/rbln/ops/flash_attn.py,sha256=wfyiCxDGf034IngzwRU160R7_DlKYpd-uWT0BDEGFks,3408
 optimum/rbln/ops/kv_cache_update.py,sha256=pxf8kAptPaQF5xE8qItvmlFOq_sgim6ZERD7AVaOtec,3221
-optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
+optimum/rbln/ops/linear.py,sha256=1_7Hg-9wXxhu97fqPobotLQx17k7VPeSSL91_9Z7EDg,1018
+optimum/rbln/transformers/__init__.py,sha256=rW2wEgNpkcBwrrib2tui5sEpw04s1YUDHB50m2L7Os8,4353
 optimum/rbln/transformers/modeling_alias.py,sha256=yx7FnZQWAnrWzivaO5hI7T6i-fyLzt2tMIXG2oDNbPo,1657
 optimum/rbln/transformers/modeling_generic.py,sha256=aaZWsqVDCRvH03q-Wen7DMfLr7Gy-u-I0mTw0aYqWjk,18195
 optimum/rbln/transformers/modeling_rope_utils.py,sha256=3zwkhYUyTZhxCJUSmwCc88iiY1TppRWEY9ShwUqNB2k,14293
-optimum/rbln/transformers/models/__init__.py,sha256=zGnYODR-_T65tv6jFjtC8l01LC4vjfm41bM4doCXRvY,3835
+optimum/rbln/transformers/models/__init__.py,sha256=Qyt9E61FDpnyAXTmRKDbv7CTtn-ml9cITvvNVqhwrnA,3992
 optimum/rbln/transformers/models/auto/__init__.py,sha256=GvGbb3ZpMv-h6euXeZ42jSizoOfrL2O1uvpAnfKxYEo,1034
 optimum/rbln/transformers/models/auto/auto_factory.py,sha256=IK9jFrJ3EEzYQa9_aKpcp2TO68M5YGkA-HcfBVpA2QU,7027
 optimum/rbln/transformers/models/auto/modeling_auto.py,sha256=Un9qoqdy3dO8JBza_bTJF_6_fRVNM9QisihSgTRFI-o,3933
 optimum/rbln/transformers/models/bart/__init__.py,sha256=32HPe0_GIO0hp9U464Iv6Jd7M-1nop9g8hA1UZMHhyw,674
 optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=Oo-Cdne7igKEex8wwP-gztKJHgs5GLHQjK1oc3IZIDE,5801
-optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=CUF5PE9TxJxtO1VpuGgeKrL_u6PdsKxstlZDthYSXgU,5829
+optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=naFpsOSjNRG8s5QPjeAsYCk2oJCxnn0Au0aYnMKZOBY,5679
 optimum/rbln/transformers/models/bert/__init__.py,sha256=YVV7k_laU6yJBawZrgjIWjRmIF-Y4oQQHqyf8lsraQs,691
 optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaMCa1EfXttNJkqCJUIZo3CeZ9YY,4674
 optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
 optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=m93-qKN7NMw3i0XDmFmttmRIRK4np_fWtLFlBb2RFgU,41351
-optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=uGdPGcFrWm2gAwFLjfBiALwFsl49VGCReVi4NUfOPxM,38898
+optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=qeZWdfLU0gCssxBODJsjQWMjfQWxK9vgC2Xt9eA5j4I,39147
 optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
 optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
 optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -92,17 +93,20 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
 optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
 optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
 optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
-optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=QelhuCWEHPL2Ut7fm0gLnzTVveBAaKSNpoa9X1AmwTI,17709
+optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=XcZb57v42wju1qOJ1AKqmtJXcmz6MEWaJZ8jyzaEiTw,17701
 optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
 optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
 optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=-fG-h0wwsfjZ3par0QHbXKA7hbvw_lPJOIf8iXQDOfM,8082
 optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=Ups6drBbYe4wEAiBLcBIyO9wqrIQbvOPFR_ybbAgR8c,9722
+optimum/rbln/transformers/models/time_series_transformers/__init__.py,sha256=RL4SO8tKEd4wQrzyU4Nv4-hhITKPhblUsBd3anXNkA8,1079
+optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py,sha256=1Ippt0Rmt2TxJ5X4-4tlALQOkKmOfMaTrbOLWIUIKWw,16614
+optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py,sha256=ohoP4sAxyQZwrQ6euGfRx9w_pPWAh6KT9nKC8Y9taes,14006
 optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
 optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
 optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
 optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
-optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=U9zK49DcSdXuoK_UOsVPsyKe6EJ5CQR8QZhpgi23EUU,16275
-optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=ArQPOgiRVu-XddEN5FXVl1OlCoGF6uY7jGoWTj3Nfe4,13005
+optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=GegyAi3a8fF0psdYsffTQ1pC4KAUqE7WYLj4ZqObWXI,18184
+optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=k_aDk2B58IxQimf6yW36Wgc0uw5PqB85Or8ie_6ZZ70,14205
 optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
 optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
 optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -116,7 +120,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.7.4a0.dist-info/METADATA,sha256=tXU0EmgjFJug_Cvmw8S9NeEZ2z9XpgamFwgMQTTCa1U,5300
-optimum_rbln-0.7.4a0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-optimum_rbln-0.7.4a0.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-optimum_rbln-0.7.4a0.dist-info/RECORD,,
+optimum_rbln-0.7.4a1.dist-info/METADATA,sha256=dMl4yloIz6iqjC2SN8CE1rVP9Kftw50Z01zocntnguE,5300
+optimum_rbln-0.7.4a1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.7.4a1.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.7.4a1.dist-info/RECORD,,