optimum-rbln 0.7.3a3__py3-none-any.whl → 0.7.3a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/modeling.py +3 -0
- optimum/rbln/modeling_base.py +12 -0
- optimum/rbln/ops/attn.py +20 -59
- optimum/rbln/ops/flash_attn.py +12 -28
- optimum/rbln/ops/kv_cache_update.py +4 -4
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +24 -41
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +91 -2
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -19
- optimum/rbln/transformers/models/t5/t5_architecture.py +4 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +10 -15
- {optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/RECORD +15 -15
- {optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.3a3'
-__version_tuple__ = version_tuple = (0, 7, 3, 'a3')
+__version__ = version = '0.7.3a4'
+__version_tuple__ = version_tuple = (0, 7, 3, 'a4')
optimum/rbln/modeling.py
CHANGED
@@ -134,6 +134,9 @@ class RBLNModel(RBLNBaseModel):
         for preprocessor in preprocessors:
             preprocessor.save_pretrained(save_dir_path / subfolder)
 
+        # ad-hoc
+        rbln_kwargs["n_model_params"] = sum(p.numel() for p in model.parameters())
+
         # Get compilation arguments (e.g. input_info)
         rbln_config: RBLNConfig = cls.get_rbln_config(
             preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
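The new `n_model_params` entry is a plain parameter count that later feeds the KV-cache block budgeting in `modeling_decoderonly.py`. A minimal sketch of what the added expression computes, using a toy module that is illustrative only (not optimum-rbln code):

import torch.nn as nn

# Toy module for illustration; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))

# Same expression the diff adds: total number of scalar parameters.
n_model_params = sum(p.numel() for p in model.parameters())
print(n_model_params)  # (16*32 + 32) + (32*8 + 8) = 808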
optimum/rbln/modeling_base.py
CHANGED
@@ -282,6 +282,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             **kwargs,
         )
 
+    @classmethod
+    def _check_compiled_models(
+        cls, compiled_models: Dict[str, rebel.RBLNCompiledModel], rbln_config: RBLNConfig, config: "PretrainedConfig"
+    ):
+        # check compiled model can create runtimes.
+        # this logic currently only works in LLM
+        # fail when LLM model using Paged Attention can't guarantee max sequence length
+        pass
+
     @classmethod
     def _from_compiled_models(
         cls,
@@ -295,6 +304,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     ):
         if isinstance(model_save_dir, str):
             model_save_dir = Path(model_save_dir)
+
+        cls._check_compiled_models(compiled_models=rbln_compiled_models, rbln_config=rbln_config, config=config)
+
         # FIXME:: Should we convert it?
         compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
         rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
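`_check_compiled_models` lands as a no-op hook that `_from_compiled_models` now calls before runtimes are created. A hypothetical override, sketched only to show where such a check would plug in; the subclass name and the specific check are invented and are not optimum-rbln behavior:

class MyLLMModel(RBLNBaseModel):  # hypothetical subclass, for illustration only
    @classmethod
    def _check_compiled_models(cls, compiled_models, rbln_config, config):
        # Example check: fail fast if no compiled model named "prefill" exists,
        # instead of failing later when runtimes are instantiated.
        if "prefill" not in compiled_models:
            raise RuntimeError("Expected a compiled 'prefill' model for LLM inference.")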
optimum/rbln/ops/attn.py
CHANGED
@@ -28,7 +28,7 @@ else:
 def register_rbln_custom_paged_attention():
     torch.library.define(
         "rbln_custom_ops::paged_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
@@ -57,28 +57,17 @@ def register_rbln_custom_paged_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_attn_decode")
     def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
@@ -105,23 +94,20 @@ def register_rbln_custom_paged_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
-            - empty_kcache: Same shape as input kcache - Placeholder for compiler
-            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_attn_prefill")
     def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
-        return q
+        return q
 
 
 @lru_cache
 def register_rbln_custom_paged_causal_attention():
     torch.library.define(
         "rbln_custom_ops::paged_causal_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
@@ -149,28 +135,17 @@ def register_rbln_custom_paged_causal_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_causal_attn_decode")
    def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_causal_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
@@ -197,23 +172,20 @@ def register_rbln_custom_paged_causal_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
-            - empty_kcache: Same shape as input kcache - Placeholder for compiler
-            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
     def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
-        return q
+        return q
 
 
 @lru_cache
 def register_rbln_custom_add_softmax_attention():
     torch.library.define(
         "rbln_custom_ops::add_softmax_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
@@ -240,21 +212,10 @@ def register_rbln_custom_add_softmax_attention():
         - scale: [] - Attention scale factor
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::add_softmax_attn_decode")
     def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
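All of these schema changes replace multi-output ops (attention output plus KV-cache placeholders) with a single-Tensor return, and the CPU and fake stubs now just return `q`. A minimal standalone sketch of that registration pattern, assuming PyTorch 2.4+ where `torch.library.register_fake` is available; the `demo_ops` namespace and tensor shapes are made up and are not RBLN ops:

import torch

# Define a custom op whose schema returns a single Tensor, as in the updated files.
torch.library.define(
    "demo_ops::attn_decode",
    "(Tensor q, Tensor kcache, Tensor vcache) -> Tensor",
)

@torch.library.impl("demo_ops::attn_decode", "cpu")
def attn_decode_cpu(q, kcache, vcache):
    # Placeholder CPU body; the real kernel is generated by the accelerator compiler.
    return q.clone()

@torch.library.register_fake("demo_ops::attn_decode")
def attn_decode_fake(q, kcache, vcache):
    # Fake (meta) impl: only propagates shape/dtype, mirroring the single-tensor return.
    return torch.empty_like(q)

out = torch.ops.demo_ops.attn_decode(
    torch.randn(1, 8, 4, 1, 64), torch.zeros(2, 8, 128, 64), torch.zeros(2, 8, 128, 64)
)
print(out.shape)  # torch.Size([1, 8, 4, 1, 64])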
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -28,71 +28,55 @@ else:
 def register_rbln_custom_paged_flash_attention():
     torch.library.define(
         "rbln_custom_ops::paged_flash_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
     def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_attn_decode")
     def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_flash_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
     def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
     def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
 
 
 @lru_cache
 def register_rbln_custom_paged_flash_causal_attention():
     torch.library.define(
         "rbln_custom_ops::paged_flash_causal_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
     def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
     def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_flash_causal_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
     def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
     def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
optimum/rbln/ops/kv_cache_update.py
CHANGED
@@ -45,10 +45,10 @@ def register_rbln_custom_cache_update():
 
         # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
         # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
-
+        cache.slice_scatter(state, dim=axis, start=s, end=e)
 
-        #
-        return
+        # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
+        return torch.empty([256])
 
     # Register a "fake" implementation of the "rbln_cache_update" operation.
     # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
@@ -57,4 +57,4 @@ def register_rbln_custom_cache_update():
         # Return a tensor with the same shape as the input cache tensor.
         # This is a placeholder for the abstract implementation and does not perform any actual computation.
         # Like the actual implementation, the abstraction assumes in-place device-side updates.
-        return torch.
+        return torch.empty([256])
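The `cache.slice_scatter(...)` call above is standard PyTorch. A minimal eager-mode sketch of what it computes; note that in eager execution `slice_scatter` returns a new tensor rather than mutating `cache`, and the in-place behavior described in the comments is how the compiled RBLN op treats it:

import torch

cache = torch.zeros(2, 8, 4)
state = torch.ones(2, 3, 4)

# Returns a copy of `cache` with `state` written into cache[:, 2:5] along dim=1.
updated = cache.slice_scatter(state, dim=1, start=2, end=5)

assert torch.equal(updated[:, 2:5], state)
assert torch.equal(cache, torch.zeros(2, 8, 4))  # eager slice_scatter does not mutate `cache`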
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
@@ -281,7 +281,7 @@ class DecoderOnlyWrapper(nn.Module):
             _past_key_values.append(past_key_value)
         past_key_values = _past_key_values
 
-        logit
+        logit = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -292,15 +292,7 @@ class DecoderOnlyWrapper(nn.Module):
             block_tables=block_tables,
         )
 
-
-        _present_key_values = ()
-        for i in range(self.num_hidden_layers):
-            key_states = present_key_values[i][0]
-            value_states = present_key_values[i][1]
-            _present_key_values = _present_key_values + (key_states, value_states)
-        present_key_values = _present_key_values
-
-        return logit, present_key_values
+        return logit
 
 
 class DecoderOnlyForCausalLM(nn.Module):
@@ -353,7 +345,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         block_tables: Optional[torch.Tensor] = None,
     ):
         # outputs
-        hidden_states
+        hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -367,8 +359,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self._original_mod.lm_head(hidden_states)
-
-        return output
+        return logits
 
 
 class DecoderOnlyModel(nn.Module):
@@ -484,20 +475,19 @@ class DecoderOnlyModel(nn.Module):
         else:
             seq_positions = cache_position[:, :1]
 
-        present_key_values = past_key_values
         for layer in self.layers:
-            hidden_states
+            hidden_states = layer(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 seq_positions=seq_positions,
-                past_key_values=
+                past_key_values=past_key_values,
                 cos=cos,
                 sin=sin,
                 block_tables=block_tables,
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        return hidden_states
 
 
 class DecoderOnlyLayer(nn.Module):
@@ -559,7 +549,7 @@ class DecoderOnlyLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
@@ -576,7 +566,7 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states = self._original_mod.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
-        return hidden_states
+        return hidden_states
 
 
 class DecoderOnlyAttention(nn.Module):
@@ -678,7 +668,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and self.phase == "prefill":
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
-        attn_output
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
@@ -690,12 +680,9 @@ class DecoderOnlyAttention(nn.Module):
             block_tables=block_tables,
             block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state
 
         attn_outputs = self.o_proj(attn_output)
-
-        return attn_outputs, past_key_values
+        return attn_outputs
 
 
 class AttentionOp(nn.Module):
@@ -733,7 +720,7 @@ class AttentionOp(nn.Module):
             scale: Scale applied to attn weights
 
         Returns:
-
+            Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
         """
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
@@ -756,7 +743,7 @@ class AttentionOp(nn.Module):
 
         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -769,7 +756,7 @@ class AttentionOp(nn.Module):
                     block_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -783,7 +770,7 @@ class AttentionOp(nn.Module):
 
         else:
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -796,7 +783,7 @@ class AttentionOp(nn.Module):
                     block_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -812,7 +799,7 @@ class AttentionOp(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
-        return attn_output
+        return attn_output
 
 
 def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
@@ -947,7 +934,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
         if cos is not None and sin is not None:
             query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
 
-        attn_output
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
@@ -959,13 +946,9 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             block_tables=block_tables,
             kvcache_block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state
 
         attn_outputs = self.o_proj(attn_output)
-
-
-        return attn_outputs, past_key_values
+        return attn_outputs
 
 
 class FlashAttentionOp(AttentionOp):
@@ -1019,7 +1002,7 @@ class FlashAttentionOp(AttentionOp):
 
         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -1033,7 +1016,7 @@ class FlashAttentionOp(AttentionOp):
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -1047,7 +1030,7 @@ class FlashAttentionOp(AttentionOp):
                 )
         else:
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -1061,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -1079,4 +1062,4 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
-        return attn_output
+        return attn_output
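A recurring theme in this file is that attention layers no longer thread `present_key_values` back to the caller: the paged-attention custom ops update the KV cache in place, so each layer returns only its output tensor. A toy, framework-free sketch of that pattern, which is illustrative only and not optimum-rbln code:

import torch
import torch.nn as nn

class ToyCachedAttention(nn.Module):
    def __init__(self, num_heads=2, head_dim=4, max_len=16):
        super().__init__()
        # Preallocated KV buffers that the layer writes into in place.
        self.register_buffer("kcache", torch.zeros(1, num_heads, max_len, head_dim))
        self.register_buffer("vcache", torch.zeros(1, num_heads, max_len, head_dim))

    def forward(self, q, k, v, pos):
        # In-place cache update replaces returning present_key_values to the caller.
        self.kcache[:, :, pos : pos + k.shape[2]] = k
        self.vcache[:, :, pos : pos + v.shape[2]] = v
        keys = self.kcache[:, :, : pos + k.shape[2]]
        values = self.vcache[:, :, : pos + v.shape[2]]
        attn = torch.softmax(q @ keys.transpose(-1, -2) / keys.shape[-1] ** 0.5, dim=-1)
        return attn @ values  # single tensor, mirroring the new return convention

layer = ToyCachedAttention()
out = layer(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4), pos=3)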
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import math
 from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
@@ -565,6 +566,72 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return compile_model(quantize_config=quantize_config)
 
+    @classmethod
+    def get_maximum_num_blocks(
+        cls,
+        config: PretrainedConfig,
+        tensor_parallel_size: int,
+        kvcache_block_size: int,
+        nbits_per_param: int,
+        n_model_params: int,
+    ) -> int:
+        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
+        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
+        vocab_size = config.vocab_size
+        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
+        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+        TARGET_DRAM_LIMIT = int(tensor_parallel_size * 15.7 * 2**30)  # 16GB  # TODO(jongho): need a more accurate value
+
+        def align(x: int, nbytes: int) -> int:
+            return int(math.ceil(x / nbytes) * nbytes)
+
+        def align_2MB(x: int) -> int:
+            return align(x, 2 * 1024 * 1024)
+
+        def get_kernel_size() -> int:
+            # TODO: Implement
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+
+            return layer_nbytes + lm_heads_nbytes
+
+        available_dram = TARGET_DRAM_LIMIT - get_kernel_size()
+
+        buffer = 2**30  # 1GB
+        if tensor_parallel_size <= 2:
+            buffer /= 4
+
+        available_dram -= buffer
+
+        def get_nbytes_per_block() -> int:
+            return (
+                align_2MB(
+                    kvcache_block_size
+                    * head_dim
+                    * math.ceil(num_key_value_heads / tensor_parallel_size)  # Shard
+                    * 2  # (fp16)
+                )
+                * num_layers
+                * 2  # (k, v)
+                * tensor_parallel_size
+            )
+
+        nbytes_per_block = get_nbytes_per_block()
+        n_blocks = available_dram // nbytes_per_block
+
+        return n_blocks, nbytes_per_block
+
     @classmethod
     def _get_rbln_config(
         cls,
@@ -618,8 +685,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             rbln_kvcache_block_size = rbln_kvcache_partition_len
 
-
-
+        max_num_blocks, nbytes_per_block = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
+            kvcache_block_size=rbln_kvcache_block_size,
+            nbits_per_param=16 if rbln_quantization is None else 4,  # TODO(jongho): FIX Ad-hoc
+            n_model_params=rbln_kwargs["n_model_params"],
+        )
+        model_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+        rbln_kvcache_num_blocks = min(model_num_blocks, max_num_blocks)
+
+        required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
+        if rbln_kvcache_num_blocks < required_blocks:
+            rbln_kvcache_num_blocks = required_blocks
+
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+
+        if rbln_kvcache_num_blocks < rbln_batch_size:
+            raise RuntimeError(
+                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                "Ensure the number of blocks is at least equal to the batch size."
+            )
 
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
@@ -719,6 +805,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "kvcache_block_size": rbln_kvcache_block_size,
                 "attn_impl": rbln_attn_impl,
                 "kvcache_num_blocks": rbln_kvcache_num_blocks,
+                "model_num_blocks": model_num_blocks,
+                "max_num_blocks": max_num_blocks,
+                "nbytes_per_block": nbytes_per_block,
             }
         )
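For a sense of scale, here is a standalone sketch of the same block-budget arithmetic as `get_maximum_num_blocks`, with made-up inputs: a roughly Llama-2-7B-like configuration, tensor_parallel_size=4, and 16-bit weights. The constants mirror the diff; the model figures are assumptions, not values from optimum-rbln.

import math

# Assumed model figures (roughly Llama-2-7B); not taken from the diff.
n_model_params = 6_738_415_616
num_layers, num_kv_heads, head_dim = 32, 32, 128
vocab_size, hidden_size = 32000, 4096
tensor_parallel_size, kvcache_block_size, nbits_per_param = 4, 128, 16

align = lambda x, n: int(math.ceil(x / n) * n)
align_2mb = lambda x: align(x, 2 * 1024 * 1024)

# Weight (kernel) footprint, aligned per shard as in the diff.
lm_head_params = align(vocab_size, 64) * hidden_size
kernel = (
    align_2mb(lm_head_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
    + align_2mb((n_model_params - lm_head_params) * nbits_per_param // 8 / num_layers / tensor_parallel_size)
    * num_layers
    * tensor_parallel_size
)

# DRAM budget minus weights and a safety buffer (1 GB, quartered for TP <= 2).
dram = int(tensor_parallel_size * 15.7 * 2**30)
buffer = 2**30 if tensor_parallel_size > 2 else 2**30 // 4
available = dram - kernel - buffer

# Bytes one KV-cache block occupies across all layers, K and V, and all shards (fp16).
nbytes_per_block = (
    align_2mb(kvcache_block_size * head_dim * math.ceil(num_kv_heads / tensor_parallel_size) * 2)
    * num_layers * 2 * tensor_parallel_size
)

print(available // nbytes_per_block)  # maximum number of blocks that fit in the budget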
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
CHANGED
@@ -114,11 +114,9 @@ class Seq2SeqEncoderWrapper(nn.Module):
 
         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
         batch_axis = torch.tensor(1, dtype=torch.int16)
-
-            cross_key_values, cross_kv, b_idx[0], batch_axis
-        )
+        enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
 
-        return
+        return enc_out
 
 
 class Seq2SeqDecoderWrapper(nn.Module):
@@ -193,7 +191,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
         # decode
-        lm_logits
+        lm_logits = self.conditional_generation(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
@@ -203,9 +201,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
             block_tables=block_tables,
         )
 
-
-
-        return outputs
+        return lm_logits
 
 
 class Seq2SeqForConditionalGeneration(nn.Module):
@@ -250,7 +246,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
         cache_position,
         block_tables: Optional[torch.Tensor] = None,
     ):
-        hidden_states
+        hidden_states = self.decoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
@@ -265,7 +261,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
 
         lm_logits = self.lm_head(hidden_states)
 
-        return lm_logits
+        return lm_logits
 
 
 class Seq2SeqDecoder(torch.nn.Module):
@@ -326,11 +322,10 @@ class Seq2SeqDecoder(torch.nn.Module):
         hidden_states = self.apply_position_embedding(hidden_states, cache_position)
 
         # iterate decoder_layer
-        self_present_key_values = ()
         for decoder_layer, self_past_key_value, cross_past_key_value in zip(
             self.layers, self_past_key_values, cross_past_key_values
         ):
-            hidden_states
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 encoder_attention_mask=encoder_attention_mask,
@@ -339,12 +334,11 @@ class Seq2SeqDecoder(torch.nn.Module):
                 cache_position=cache_position,
                 block_tables=block_tables,
             )
-            self_present_key_values += self_present_key_value
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
 
-        return hidden_states
+        return hidden_states
 
 
 class Seq2SeqDecoderLayer(torch.nn.Module):
@@ -404,7 +398,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.pre_self_attn_layer_norm(hidden_states)
-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
@@ -429,7 +423,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         # Feed-Forward Block
         hidden_states = self.ff_layer(hidden_states)
 
-        return hidden_states
+        return hidden_states
 
 
 class Seq2SeqSelfAttention(nn.Module):
@@ -492,12 +486,11 @@ class Seq2SeqSelfAttention(nn.Module):
         if attention_mask is not None:
             args.insert(3, attention_mask.unsqueeze(2))
 
-        attn_output
+        attn_output = self.attn_decode(*args)
 
         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
 
         attn_output = self.out_proj(attn_output)
-        present_key_value = (key_states, value_states)
 
-        return attn_output
+        return attn_output
optimum/rbln/transformers/models/t5/t5_architecture.py
CHANGED
@@ -88,7 +88,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
         # decode
-        lm_logits
+        lm_logits = self.conditional_generation(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
@@ -97,9 +97,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
             cache_position=cache_position,
         )
 
-
-
-        return outputs
+        return lm_logits
 
 
 class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
@@ -187,7 +185,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)
 
-        attn_output
+        attn_output = self.attn_decode(
             query_states,
             key_states,
             value_states,
@@ -204,9 +202,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
 
         attn_output = self.out_proj(attn_output)
-
-
-        return attn_output, present_key_value
+        return attn_output
 
 
 class T5CrossAttention(nn.Module):
optimum/rbln/transformers/models/whisper/whisper_architecture.py
CHANGED
@@ -78,9 +78,9 @@ class WhisperEncoderWrapper(torch.nn.Module):
         # 3. update cross_attention's past_key_value to the device-dram for optimization.
         bidx = torch.tensor(0, dtype=torch.int16)
         axis = torch.tensor(1, dtype=torch.int16)
-
+        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
 
-        return
+        return enc_output
 
 
 class WhisperDecoderWrapper(torch.nn.Module):
@@ -119,7 +119,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
         # Decode
-        sequence_output,
+        sequence_output, cross_attentions = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             cache_position=cache_position,
@@ -128,9 +128,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
         )
 
         lm_logits = self.proj_out(sequence_output)
-
         outputs = (lm_logits,)
-        outputs += self_present_key_values
 
         if self.output_attentions:
             # deocder's cross attention is used for token_timestamps
@@ -168,26 +166,23 @@ class WhisperDecoder(nn.Module):
         # prepare casual_attn_mask
         attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
 
-        self_present_key_values = ()
         cross_attentions = ()
         # iterate decoder_layer
         for self_past_key_value, cross_past_key_value, decoder_layer in zip(
             self_past_key_values, cross_past_key_values, self.layers
         ):
-
+            hidden_states, cross_attn_weights = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 self_past_key_value=self_past_key_value,
                 cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
             )
-
-            self_present_key_values += layer_outputs[1]
-            cross_attentions += (layer_outputs[2],)
+            cross_attentions += (cross_attn_weights,)
 
         hidden_states = self.layer_norm(hidden_states)
 
-        return hidden_states,
+        return hidden_states, cross_attentions
 
 
 class WhisperDecoderLayer(nn.Module):
@@ -214,7 +209,7 @@ class WhisperDecoderLayer(nn.Module):
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
@@ -238,7 +233,7 @@ class WhisperDecoderLayer(nn.Module):
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states
 
-        return hidden_states,
+        return hidden_states, cross_attn_weights
 
 
 class WhisperAttention(nn.Module):
@@ -276,7 +271,7 @@ class WhisperSelfAttention(WhisperAttention):
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
-        attn_output
+        attn_output = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
             query_states,
             key_states,
             value_states,
@@ -292,7 +287,7 @@ class WhisperSelfAttention(WhisperAttention):
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output
+        return attn_output
 
 
 class WhisperCrossAttention(WhisperAttention):
{optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.3a3
+Version: 0.7.3a4
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
{optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 optimum/rbln/__init__.py,sha256=eHi15YM3989AcX52jka9rUmgAtlp1PHqMNwBEdOfuu8,6554
-optimum/rbln/__version__.py,sha256=
-optimum/rbln/modeling.py,sha256=
-optimum/rbln/modeling_base.py,sha256=
+optimum/rbln/__version__.py,sha256=MLlg_138GxyhciEP0ZB5dPN8vriXkicRnaZiwqygxOY,519
+optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
+optimum/rbln/modeling_base.py,sha256=Ow73GVJF1N5cDFO8_rgirtGj1wC-cXBDyqXHW5PCybA,22270
 optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
 optimum/rbln/diffusers/__init__.py,sha256=pOyoXv3-JRzTBSwPKbgLS9H6F2K9dJdReEmpGhcLQYU,3283
 optimum/rbln/diffusers/modeling_diffusers.py,sha256=zqVNgH9oeOx2iNE7VsW_FinVf4s6G5Idyh4TKz7XJJg,21116
@@ -40,9 +40,9 @@ optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_x
 optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
 optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
 optimum/rbln/ops/__init__.py,sha256=TxOmsN0u3PmyK4Sb89qbiC4rePOlkvUT7Lm6wVoTnY0,941
-optimum/rbln/ops/attn.py,sha256=
-optimum/rbln/ops/flash_attn.py,sha256=
-optimum/rbln/ops/kv_cache_update.py,sha256=
+optimum/rbln/ops/attn.py,sha256=3EqU63Oj4zI4rLbkRycorsscXeD-IpKzt9N1MhkMa5o,10374
+optimum/rbln/ops/flash_attn.py,sha256=wfyiCxDGf034IngzwRU160R7_DlKYpd-uWT0BDEGFks,3408
+optimum/rbln/ops/kv_cache_update.py,sha256=pxf8kAptPaQF5xE8qItvmlFOq_sgim6ZERD7AVaOtec,3221
 optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
 optimum/rbln/transformers/modeling_alias.py,sha256=yx7FnZQWAnrWzivaO5hI7T6i-fyLzt2tMIXG2oDNbPo,1657
 optimum/rbln/transformers/modeling_generic.py,sha256=aaZWsqVDCRvH03q-Wen7DMfLr7Gy-u-I0mTw0aYqWjk,18195
@@ -59,8 +59,8 @@ optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaM
 optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
-optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=
-optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=
+optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=7OIKteJLKNxOLOg0w3lLOM7TxZovQn4jkglI9wRkrtQ,40609
+optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=W9HnxJoTz78Wc4X5Q3sMSHhMTSa7-9uQCFlnqNVozvA,38932
 optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
 optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
 optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -91,16 +91,16 @@ optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz
 optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
 optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
 optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
-optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=
+optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
 optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
 optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
-optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=
+optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=AArCQhZRETVM583wlIRzMFOSYq7t2nzxaAeyhZxyxKk,9508
 optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
 optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
 optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
 optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
 optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
-optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=
+optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=Yn6yFpmw6IQbWlnpIMAdEUsNF6huXgaKzGMUZbhSLdo,12572
 optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
 optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
 optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -114,7 +114,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.7.3a3.dist-info/METADATA,sha256=
-optimum_rbln-0.7.3a3.dist-info/WHEEL,sha256=
-optimum_rbln-0.7.3a3.dist-info/licenses/LICENSE,sha256=
-optimum_rbln-0.7.3a3.dist-info/RECORD,,
+optimum_rbln-0.7.3a4.dist-info/METADATA,sha256=8VNTOVgsgFtcFUuZ9VEeRQfC2LEB60OFmW92hlJo8V8,5300
+optimum_rbln-0.7.3a4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.7.3a4.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.7.3a4.dist-info/RECORD,,
{optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/WHEEL
File without changes
{optimum_rbln-0.7.3a3.dist-info → optimum_rbln-0.7.3a4.dist-info}/licenses/LICENSE
File without changes