optimum-rbln 0.7.3a2__py3-none-any.whl → 0.7.3a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +48 -52
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +19 -24
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/RECORD +7 -7
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.3a2'
-__version_tuple__ = version_tuple = (0, 7, 3, 'a2')
+__version__ = version = '0.7.3a3'
+__version_tuple__ = version_tuple = (0, 7, 3, 'a3')
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
CHANGED
@@ -54,7 +54,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor,
         free_block_pool: Deque,
         kvcache_block_size: int,
-        kvcache_num_blocks: int,
         use_attention_mask: bool,
         attn_impl: str,
         **kwargs: Any,
@@ -72,7 +71,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         self.free_block_pool = free_block_pool
 
         self.kvcache_block_size = kvcache_block_size
-        self.empty_block = kvcache_num_blocks - 1
+        self.empty_block = -1
         self.attn_impl = attn_impl
 
         if self.phase == "prefill":
@@ -97,58 +96,61 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.Tensor: Updated block tables.
         """
 
-
+        NO_BLOCKS_ERROR = (
+            "No memory blocks are available for allocation."
+            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln."
+            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html)."
+            "Using vllm-rbln should fix this issue and enhance inference performance."
+        )
+
+        def update_block(batch_idx: int, block_idx: int):
             """
-            Helper function to update the block table for a given batch index and block index.
             If the block is empty (empty_block), allocates a block from the free_block_pool.
-
-            Args:
-                batch_idx (int): Batch index.
-                block_idx (int): Block index.
-
-            Raises:
-                RuntimeError: Raised if no available blocks are found in the free_block_pool.
             """
             if self.block_tables[batch_idx][block_idx] == self.empty_block:
                 if self.free_block_pool:
                     block = self.free_block_pool.popleft()
                     self.block_tables[batch_idx][block_idx] = block
                 else:
-                    raise RuntimeError(
+                    raise RuntimeError(NO_BLOCKS_ERROR)
 
-
-
-
+        def replace_empty_block(block_tables: torch.Tensor):
+            """
+            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+            """
+            if not torch.any(block_tables == self.empty_block):
+                return block_tables.clone()
+            elif self.free_block_pool:
+                _free_block = self.free_block_pool[0]
+                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
             else:
-
-
+                raise RuntimeError(NO_BLOCKS_ERROR)
+
+        if self.phase == "prefill":
+            # Track previously used blocks and return them to the free_block_pool and
+            # reset the current batch's block table to empty blocks
+            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+            self.free_block_pool.extend(prev_blocks)
+            self.block_tables[batch_idx].fill_(self.empty_block)
+
+            # Get the start (s) and end (e) positions from cache_position and
+            # iterate over the cache positions to allocate necessary blocks
+            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+            for position in range(s, e + 1, self.kvcache_block_size):
+                block_idx = position // self.kvcache_block_size
+                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                update_block(batch_idx, block_idx)
+
+            return replace_empty_block(self.block_tables[batch_idx])
+        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
         else:
-
-
-
-
-            self.free_block_pool.extend(prev_blocks)
-            self.block_tables[batch_idx].fill_(self.empty_block)
-
-            # Get the start (s) and end (e) positions from cache_position and
-            # iterate over the cache positions to allocate necessary blocks
-            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-            for position in range(s, e + 1, self.kvcache_block_size):
-                block_idx = position // self.kvcache_block_size
-                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                update_block(batch_idx, block_idx)
-
-            return self.block_tables[batch_idx]
-
-        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-        else:
-            for b_idx in range(self.batch_size):
-                position = cache_position[b_idx][0].item()
-                block_idx = position // self.kvcache_block_size
-                update_block(b_idx, block_idx)
+            for b_idx in range(self.batch_size):
+                position = cache_position[b_idx][0].item()
+                block_idx = position // self.kvcache_block_size
+                update_block(b_idx, block_idx)
 
-
+            return replace_empty_block(self.block_tables)
 
     def forward(
         self,
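Read as a whole, the rewritten hunk implements a small paged KV-cache allocator: `-1` now marks an unallocated table slot, the prefill path first returns a sequence's previously held blocks to the pool, and `replace_empty_block` swaps any remaining `-1` entries for a dummy (but valid) block id before the table is handed to the runtime. A minimal self-contained sketch of that scheme, with hypothetical names and sizes rather than the actual `RBLNRuntimeModel` internals:

```python
from collections import deque

import torch

EMPTY = -1  # sentinel for "no physical block assigned"

block_tables = torch.full((2, 4), EMPTY, dtype=torch.int16)  # 2 sequences, 4 slots each
free_pool = deque(range(8))  # 8 physical blocks, all allocatable

def allocate(batch_idx: int, block_idx: int) -> None:
    # Assign a physical block to an empty slot, taking it from the free pool.
    if block_tables[batch_idx, block_idx] == EMPTY:
        if not free_pool:
            raise RuntimeError("No memory blocks are available for allocation.")
        block_tables[batch_idx, block_idx] = free_pool.popleft()

def fill_empty(table: torch.Tensor) -> torch.Tensor:
    # Unused slots must still hold a valid block id when handed to the runtime,
    # so point them at an arbitrary free block without consuming it.
    if not torch.any(table == EMPTY):
        return table.clone()
    if not free_pool:
        raise RuntimeError("No memory blocks are available for allocation.")
    return torch.where(table == EMPTY, torch.tensor(free_pool[0], dtype=table.dtype), table)

# Prefill a 3-block prompt for sequence 0, then read back a runtime-safe table.
for block_idx in range(3):
    allocate(0, block_idx)
print(fill_empty(block_tables[0]))  # e.g. tensor([0, 1, 2, 3], dtype=torch.int16)
```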
@@ -380,14 +382,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
-
-
-
-
-        block_tables = torch.zeros(
-            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
-        ).fill_(self.kvcache_num_blocks - 1)
-        free_block_pool = deque(x for x in range(self.kvcache_num_blocks - 1))
+        block_tables = torch.zeros(
+            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+        ).fill_(-1)
+        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
 
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
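This initialization change is where the capacity difference comes from. The old code reserved the last physical block as the dummy that unallocated table entries pointed to (`fill_(self.kvcache_num_blocks - 1)`), so only `kvcache_num_blocks - 1` blocks could ever be allocated; the new code fills the table with the `-1` sentinel and frees every block. A worked example, assuming a hypothetical `kvcache_num_blocks = 8`:

```python
from collections import deque

kvcache_num_blocks = 8  # assumed for illustration

# Old scheme: block 7 is reserved as the dummy target, leaving 7 usable blocks.
old_pool = deque(range(kvcache_num_blocks - 1))  # deque([0, 1, 2, 3, 4, 5, 6])

# New scheme: -1 is the dummy sentinel, so all 8 blocks are usable.
new_pool = deque(range(kvcache_num_blocks))      # deque([0, 1, 2, 3, 4, 5, 6, 7])
```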
@@ -399,7 +397,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             block_tables=block_tables,
             free_block_pool=free_block_pool,
             kvcache_block_size=self.kvcache_block_size,
-            kvcache_num_blocks=self.kvcache_num_blocks,
             vocab_size=self.config.vocab_size,
             prefill_chunk_size=self.prefill_chunk_size,
             max_seq_len=self.max_seq_len,
@@ -416,7 +413,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             block_tables=block_tables,
             free_block_pool=free_block_pool,
             kvcache_block_size=self.kvcache_block_size,
-            kvcache_num_blocks=self.kvcache_num_blocks,
             use_attention_mask=self.use_attention_mask,
             attn_impl=attn_impl,
         )
optimum/rbln/transformers/models/whisper/whisper_architecture.py
CHANGED
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_cache_update
+from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update
 
 
 logger = logging.get_logger(__name__)
@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
 class WhisperWrapper:
     def __init__(self, model, rbln_token_timestamps):
         register_rbln_custom_cache_update()
+        register_rbln_custom_add_softmax_attention()
         self.encoder = WhisperEncoderWrapper(model)
         self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
 
@@ -213,7 +214,7 @@ class WhisperDecoderLayer(nn.Module):
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states,
+        hidden_states, self_present_key_value = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
@@ -224,7 +225,7 @@ class WhisperDecoderLayer(nn.Module):
         # Cross-Attention Block
         residual = hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        hidden_states, cross_attn_weights
+        hidden_states, cross_attn_weights = self.encoder_attn(
             hidden_states=hidden_states,
             past_key_value=cross_past_key_value,
         )
@@ -258,19 +259,8 @@ class WhisperAttention(nn.Module):
 
 
 class WhisperSelfAttention(WhisperAttention):
-    def rbln_cache_update(
-        self,
-        past_key_value: torch.Tensor,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        cache_position: torch.Tensor,
-    ):
-        s_idx = torch.tensor(cache_position, dtype=torch.int16)
-        axis = torch.tensor(2, dtype=torch.int16)
-
-        key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
-        value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
-        return key_states, value_states
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)
 
     def forward(
         self,
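The new `_shape` on `WhisperSelfAttention` is what makes the fused decode op in the next hunk line up: it produces a 5-D `(bsz, num_heads, 1, seq_len, head_dim)` layout, the same layout the next hunk imposes on the cached keys/values via `view(bsz, self.num_heads, 1, -1, self.head_dim)`. A quick shape check (dimensions chosen arbitrarily for illustration):

```python
import torch

bsz, seq_len, num_heads, head_dim = 1, 4, 8, 64
projected = torch.randn(bsz, seq_len, num_heads * head_dim)  # e.g. output of k_proj

shaped = projected.view(bsz, seq_len, 1, num_heads, head_dim).transpose(1, 3)
print(shaped.shape)  # torch.Size([1, 8, 1, 4, 64]) -> (bsz, num_heads, 1, seq_len, head_dim)
```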
@@ -285,22 +275,27 @@ class WhisperSelfAttention(WhisperAttention):
 
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-        key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)
 
-
-
-
+        attn_output, key_states, value_states = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(2),
+            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            cache_position.expand(bsz, 1),
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+        )
 
-        attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output,
+        return attn_output, (key_states, value_states)
 
 
-class WhisperCrossAttention(WhisperSelfAttention):
+class WhisperCrossAttention(WhisperAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -322,4 +317,4 @@ class WhisperCrossAttention(WhisperSelfAttention):
         attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights
+        return attn_output, attn_weights
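The decode path now calls a single fused kernel, `rbln_custom_ops.add_softmax_attn_decode`, instead of the separate `rbln_cache_update` plus an explicit matmul/softmax. The op itself is an RBLN custom kernel registered by `register_rbln_custom_add_softmax_attention`, so its exact semantics are not visible in this diff; judging only from the name, the argument order at the call site, and the three return values, a plausible pure-PyTorch reading (a sketch, not the kernel's actual implementation) is:

```python
import torch

def add_softmax_attn_decode_reference(query, key, value, mask, k_cache, v_cache, cache_position, scale):
    """Illustrative stand-in: single-step attention with an additive mask and in-place KV-cache update.

    All tensors are assumed to use the 5-D (bsz, num_heads, 1, seq_len, head_dim) layout
    that the call site sets up; `cache_position` is (bsz, 1) and `scale` is a 0-dim tensor.
    """
    bsz = query.shape[0]
    for b in range(bsz):
        pos = cache_position[b, 0]
        # Write this step's key/value into the cache row for sequence b.
        k_cache[b, :, :, pos] = key[b, :, :, 0]
        v_cache[b, :, :, pos] = value[b, :, :, 0]

    # "add softmax" attention: the mask is added to the scores before the softmax.
    scores = torch.matmul(query, k_cache.transpose(-1, -2)) * scale + mask
    weights = torch.softmax(scores, dim=-1)
    attn_output = torch.matmul(weights, v_cache)
    return attn_output, k_cache, v_cache
```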
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.3a2
+Version: 0.7.3a3
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 optimum/rbln/__init__.py,sha256=eHi15YM3989AcX52jka9rUmgAtlp1PHqMNwBEdOfuu8,6554
-optimum/rbln/__version__.py,sha256=
+optimum/rbln/__version__.py,sha256=jlkAV1bws10Tgk9b3JF90gq1GOekHphDutCCDtjNFJc,519
 optimum/rbln/modeling.py,sha256=3XE0IrCYbkjw9_Q9BFzZ_ri_Kyxw1g6iwfdohZB46-s,8289
 optimum/rbln/modeling_base.py,sha256=ELSPbjx7awBRM2SckkD-5gI3TIa01mfzz7gDRC1Pljk,21778
 optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
@@ -60,7 +60,7 @@ optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-u
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
 optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=x8_xQ5aGXbadJyajpJQyfgxx4YPSj62VlmmGDMnC-1E,41819
-optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=
+optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=DylKxV1kFbDv34txpuI5JrvMcSTa2W910eO9dmF0o_8,35352
 optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
 optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
 optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -100,7 +100,7 @@ optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6c
 optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
 optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
 optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
-optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=
+optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=_6PmE4-DD5QhohQwHW5M11q_L9f_ayF6StmNTlOYJdg,12896
 optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
 optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
 optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -114,7 +114,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.7.3a2.dist-info/METADATA,sha256=
-optimum_rbln-0.7.3a2.dist-info/WHEEL,sha256=
-optimum_rbln-0.7.3a2.dist-info/licenses/LICENSE,sha256=
-optimum_rbln-0.7.3a2.dist-info/RECORD,,
+optimum_rbln-0.7.3a3.dist-info/METADATA,sha256=UQs6c3GdXbPYE8wSnT6Ca9TtgfKwEgPNVZk-MoAKQPc,5300
+optimum_rbln-0.7.3a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.7.3a3.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.7.3a3.dist-info/RECORD,,
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/WHEEL
File without changes
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a3.dist-info}/licenses/LICENSE
File without changes