optimum-rbln 0.7.3a2__py3-none-any.whl → 0.7.3a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/optimum/rbln/__version__.py
+++ b/optimum/rbln/__version__.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.3a2'
-__version_tuple__ = version_tuple = (0, 7, 3)
+__version__ = version = '0.7.3a3'
+__version_tuple__ = version_tuple = (0, 7, 3, 'a3')
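
The substantive fix above is the second line: in the a2 wheel, `__version_tuple__` dropped the pre-release segment, so the tuple for an alpha build compared equal to a final 0.7.3 release. A quick sanity check against the new wheel (hypothetical usage, assuming optimum-rbln 0.7.3a3 is installed):

    # Hypothetical sanity check; assumes the 0.7.3a3 wheel is installed.
    from optimum.rbln.__version__ import __version__, __version_tuple__

    assert __version__ == "0.7.3a3"
    # The tuple now carries the pre-release tag, so alpha builds are
    # distinguishable from the final (0, 7, 3) release by tuple comparison.
    assert __version_tuple__ == (0, 7, 3, "a3")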
--- a/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -54,7 +54,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor,
         free_block_pool: Deque,
         kvcache_block_size: int,
-        kvcache_num_blocks: int,
         use_attention_mask: bool,
         attn_impl: str,
         **kwargs: Any,
@@ -72,7 +71,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         self.free_block_pool = free_block_pool
 
         self.kvcache_block_size = kvcache_block_size
-        self.empty_block = kvcache_num_blocks - 1
+        self.empty_block = -1
         self.attn_impl = attn_impl
 
         if self.phase == "prefill":
@@ -97,58 +96,61 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.Tensor: Updated block tables.
         """
 
-        def update_block(batch_idx, block_idx):
+        NO_BLOCKS_ERROR = (
+            "No memory blocks are available for allocation. "
+            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
+            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
+            "Using vllm-rbln should fix this issue and enhance inference performance."
+        )
+
+        def update_block(batch_idx: int, block_idx: int):
             """
-            Helper function to update the block table for a given batch index and block index.
             If the block is empty (empty_block), allocates a block from the free_block_pool.
-
-            Args:
-                batch_idx (int): Batch index.
-                block_idx (int): Block index.
-
-            Raises:
-                RuntimeError: Raised if no available blocks are found in the free_block_pool.
             """
             if self.block_tables[batch_idx][block_idx] == self.empty_block:
                 if self.free_block_pool:
                     block = self.free_block_pool.popleft()
                     self.block_tables[batch_idx][block_idx] = block
                 else:
-                    raise RuntimeError("Not available blocks")
+                    raise RuntimeError(NO_BLOCKS_ERROR)
 
-        if self.attn_impl == "eager":
-            if self.phase == "prefill":
-                return self.block_tables[batch_idx]
+        def replace_empty_block(block_tables: torch.Tensor):
+            """
+            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+            """
+            if not torch.any(block_tables == self.empty_block):
+                return block_tables.clone()
+            elif self.free_block_pool:
+                _free_block = self.free_block_pool[0]
+                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
             else:
-                return self.block_tables
-        # Case for 'flash_attn' attention implementation
+                raise RuntimeError(NO_BLOCKS_ERROR)
+
+        if self.phase == "prefill":
+            # Track previously used blocks and return them to the free_block_pool and
+            # reset the current batch's block table to empty blocks
+            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+            self.free_block_pool.extend(prev_blocks)
+            self.block_tables[batch_idx].fill_(self.empty_block)
+
+            # Get the start (s) and end (e) positions from cache_position and
+            # iterate over the cache positions to allocate necessary blocks
+            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+            for position in range(s, e + 1, self.kvcache_block_size):
+                block_idx = position // self.kvcache_block_size
+                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                update_block(batch_idx, block_idx)
+
+            return replace_empty_block(self.block_tables[batch_idx])
+        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
         else:
-            if self.phase == "prefill":
-                # Track previously used blocks and return them to the free_block_pool and
-                # reset the current batch's block table to empty blocks
-                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-                self.free_block_pool.extend(prev_blocks)
-                self.block_tables[batch_idx].fill_(self.empty_block)
-
-                # Get the start (s) and end (e) positions from cache_position and
-                # iterate over the cache positions to allocate necessary blocks
-                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-                for position in range(s, e + 1, self.kvcache_block_size):
-                    block_idx = position // self.kvcache_block_size
-                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                    update_block(batch_idx, block_idx)
-
-                return self.block_tables[batch_idx]
-
-            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-            else:
-                for b_idx in range(self.batch_size):
-                    position = cache_position[b_idx][0].item()
-                    block_idx = position // self.kvcache_block_size
-                    update_block(b_idx, block_idx)
+            for b_idx in range(self.batch_size):
+                position = cache_position[b_idx][0].item()
+                block_idx = position // self.kvcache_block_size
+                update_block(b_idx, block_idx)
 
-            return self.block_tables
+            return replace_empty_block(self.block_tables)
 
     def forward(
         self,
@@ -380,14 +382,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
-        if attn_impl == "eager":
-            block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).reshape(self.batch_size, 1)
-            free_block_pool = None
-        else:
-            block_tables = torch.zeros(
-                self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
-            ).fill_(self.kvcache_num_blocks - 1)
-            free_block_pool = deque(x for x in range(self.kvcache_num_blocks - 1))
+        block_tables = torch.zeros(
+            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+        ).fill_(-1)
+        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
 
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
@@ -399,7 +397,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             block_tables=block_tables,
             free_block_pool=free_block_pool,
             kvcache_block_size=self.kvcache_block_size,
-            kvcache_num_blocks=self.kvcache_num_blocks,
             vocab_size=self.config.vocab_size,
             prefill_chunk_size=self.prefill_chunk_size,
             max_seq_len=self.max_seq_len,
@@ -416,7 +413,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             block_tables=block_tables,
             free_block_pool=free_block_pool,
             kvcache_block_size=self.kvcache_block_size,
-            kvcache_num_blocks=self.kvcache_num_blocks,
             use_attention_mask=self.use_attention_mask,
             attn_impl=attn_impl,
         )
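
Taken together, the modeling_decoderonly.py hunks change the paged-attention bookkeeping in two ways: the "unallocated" sentinel in the block table is now -1 instead of `kvcache_num_blocks - 1` (the old sentinel aliased a real block id, which also forced the free pool to be one block short), and since -1 is not a valid index, `replace_empty_block()` substitutes a dummy-but-real block id before the table is handed to the runtime. The eager/flash-attention split is also gone; both paths share the same pool-based allocator, so the `kvcache_num_blocks` constructor argument is no longer needed. A minimal standalone sketch of the scheme (illustrative names and sizes, not the library API):

    # Minimal sketch of the block allocation introduced in this diff.
    # EMPTY, num_blocks, batch_size, blocks_per_seq are made up for illustration.
    from collections import deque

    import torch

    EMPTY = -1  # sentinel: this logical slot has no physical block yet
    num_blocks, batch_size, blocks_per_seq = 8, 2, 4

    block_tables = torch.full((batch_size, blocks_per_seq), EMPTY, dtype=torch.int16)
    free_pool = deque(range(num_blocks))  # every physical block is allocatable

    def allocate(batch_idx: int, block_idx: int) -> None:
        """Back a logical slot with a physical block on first touch."""
        if block_tables[batch_idx][block_idx] == EMPTY:
            if not free_pool:
                raise RuntimeError("no free KV-cache blocks")
            block_tables[batch_idx][block_idx] = free_pool.popleft()

    def materialize(tables: torch.Tensor) -> torch.Tensor:
        """Swap leftover sentinels for a dummy (but real) block id, since the
        runtime kernel cannot index the cache with -1."""
        if not torch.any(tables == EMPTY):
            return tables.clone()
        return torch.where(tables == EMPTY, free_pool[0], tables)

    allocate(0, 0)
    allocate(0, 1)
    print(materialize(block_tables[0]))  # tensor([0, 1, 2, 2], dtype=torch.int16)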
--- a/optimum/rbln/transformers/models/whisper/whisper_architecture.py
+++ b/optimum/rbln/transformers/models/whisper/whisper_architecture.py
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_cache_update
+from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update
 
 
 logger = logging.get_logger(__name__)
@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
 class WhisperWrapper:
     def __init__(self, model, rbln_token_timestamps):
         register_rbln_custom_cache_update()
+        register_rbln_custom_add_softmax_attention()
         self.encoder = WhisperEncoderWrapper(model)
         self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
 
@@ -213,7 +214,7 @@ class WhisperDecoderLayer(nn.Module):
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states, _, self_present_key_value = self.self_attn(
+        hidden_states, self_present_key_value = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
@@ -224,7 +225,7 @@ class WhisperDecoderLayer(nn.Module):
         # Cross-Attention Block
         residual = hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        hidden_states, cross_attn_weights, cross_present_key_value = self.encoder_attn(
+        hidden_states, cross_attn_weights = self.encoder_attn(
             hidden_states=hidden_states,
             past_key_value=cross_past_key_value,
         )
@@ -258,19 +259,8 @@ class WhisperAttention(nn.Module):
 
 
 class WhisperSelfAttention(WhisperAttention):
-    def rbln_cache_update(
-        self,
-        past_key_value: torch.Tensor,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        cache_position: torch.Tensor,
-    ):
-        s_idx = torch.tensor(cache_position, dtype=torch.int16)
-        axis = torch.tensor(2, dtype=torch.int16)
-
-        key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
-        value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
-        return key_states, value_states
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)
 
     def forward(
         self,
@@ -285,22 +275,27 @@ class WhisperSelfAttention(WhisperAttention):
 
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-        key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_output, key_states, value_states = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(2),
+            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            cache_position.expand(bsz, 1),
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+        )
 
-        attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights, (key_states, value_states)
+        return attn_output, (key_states, value_states)
 
 
-class WhisperCrossAttention(WhisperSelfAttention):
+class WhisperCrossAttention(WhisperAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -322,4 +317,4 @@ class WhisperCrossAttention(WhisperSelfAttention):
         attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights, (key_states, value_states)
+        return attn_output, attn_weights
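
The whisper_architecture.py hunks replace the two-step eager decode path (custom cache-update op followed by matmul/softmax in PyTorch) with a single fused RBLN custom op, `add_softmax_attn_decode`, which updates the self-attention KV cache and computes masked attention in one kernel; the attention-weights element is dropped from the self-attention return value accordingly, and `WhisperCrossAttention` now derives from plain `WhisperAttention` since it no longer shares the removed helper. For reference, the eager computation the fused op supersedes looked roughly like this (a sketch reconstructed from the removed lines, not the RBLN kernel's implementation):

    # Reference semantics of the removed eager path that the fused
    # add_softmax_attn_decode op replaces (sketch; shapes follow the old code).
    import torch
    import torch.nn.functional as F

    def add_softmax_attention_reference(query_states, key_states, value_states, attention_mask):
        # scores = Q @ K^T; the additive mask is applied before the softmax,
        # and the scale is 1.0, matching the torch.tensor(1.0) argument above
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1)
        # weighted sum over the values
        return torch.matmul(attn_weights, value_states)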
--- a/optimum_rbln-0.7.3a2.dist-info/METADATA
+++ b/optimum_rbln-0.7.3a3.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.3a2
+Version: 0.7.3a3
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
--- a/optimum_rbln-0.7.3a2.dist-info/RECORD
+++ b/optimum_rbln-0.7.3a3.dist-info/RECORD
@@ -1,5 +1,5 @@
 optimum/rbln/__init__.py,sha256=eHi15YM3989AcX52jka9rUmgAtlp1PHqMNwBEdOfuu8,6554
-optimum/rbln/__version__.py,sha256=bShBukYvw7AqWtLsut0yClygDEGsFRmxrXypqIeEXcQ,513
+optimum/rbln/__version__.py,sha256=jlkAV1bws10Tgk9b3JF90gq1GOekHphDutCCDtjNFJc,519
 optimum/rbln/modeling.py,sha256=3XE0IrCYbkjw9_Q9BFzZ_ri_Kyxw1g6iwfdohZB46-s,8289
 optimum/rbln/modeling_base.py,sha256=ELSPbjx7awBRM2SckkD-5gI3TIa01mfzz7gDRC1Pljk,21778
 optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
@@ -60,7 +60,7 @@ optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-u
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
 optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=x8_xQ5aGXbadJyajpJQyfgxx4YPSj62VlmmGDMnC-1E,41819
-optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=dyl8tDBjfe5VfU1XbKAoZS7g7F90JTYVmMuz0HTmCoE,35345
+optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=DylKxV1kFbDv34txpuI5JrvMcSTa2W910eO9dmF0o_8,35352
 optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
 optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
 optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -100,7 +100,7 @@ optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6c
 optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
 optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
 optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
-optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=eP3UgkwCRaaFjc5Jc4ZEiWxr3-L7oJx9KzpJ7eFkwUs,13158
+optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=_6PmE4-DD5QhohQwHW5M11q_L9f_ayF6StmNTlOYJdg,12896
 optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
 optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
 optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -114,7 +114,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.7.3a2.dist-info/METADATA,sha256=C-IWumO-veJFZPHpF8wcOTOE0TCDzKU1Xk_ylaqrvPM,5300
-optimum_rbln-0.7.3a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-optimum_rbln-0.7.3a2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-optimum_rbln-0.7.3a2.dist-info/RECORD,,
+optimum_rbln-0.7.3a3.dist-info/METADATA,sha256=UQs6c3GdXbPYE8wSnT6Ca9TtgfKwEgPNVZk-MoAKQPc,5300
+optimum_rbln-0.7.3a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.7.3a3.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.7.3a3.dist-info/RECORD,,