optimum-rbln 0.7.3a2__py3-none-any.whl → 0.7.3a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/modeling.py +3 -0
- optimum/rbln/modeling_base.py +12 -0
- optimum/rbln/ops/attn.py +20 -59
- optimum/rbln/ops/flash_attn.py +12 -28
- optimum/rbln/ops/kv_cache_update.py +4 -4
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +24 -41
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +139 -54
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -19
- optimum/rbln/transformers/models/t5/t5_architecture.py +4 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/RECORD +15 -15
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.3a2'
-__version_tuple__ = version_tuple = (0, 7, 3, 'a2')
+__version__ = version = '0.7.3a4'
+__version_tuple__ = version_tuple = (0, 7, 3, 'a4')
optimum/rbln/modeling.py
CHANGED
@@ -134,6 +134,9 @@ class RBLNModel(RBLNBaseModel):
         for preprocessor in preprocessors:
             preprocessor.save_pretrained(save_dir_path / subfolder)
 
+        # ad-hoc
+        rbln_kwargs["n_model_params"] = sum(p.numel() for p in model.parameters())
+
         # Get compilation arguments (e.g. input_info)
         rbln_config: RBLNConfig = cls.get_rbln_config(
             preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
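Not part of the diff: a tiny sketch of what the added line records — the total parameter count of the source model, which the decoder-only configuration step later uses to estimate how much device DRAM the weights occupy. The layer size here is hypothetical.

import torch

model = torch.nn.Linear(4096, 32000, bias=False)  # stand-in for an lm_head-sized layer
n_model_params = sum(p.numel() for p in model.parameters())
print(n_model_params)  # 131072000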
optimum/rbln/modeling_base.py
CHANGED
@@ -282,6 +282,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             **kwargs,
         )
 
+    @classmethod
+    def _check_compiled_models(
+        cls, compiled_models: Dict[str, rebel.RBLNCompiledModel], rbln_config: RBLNConfig, config: "PretrainedConfig"
+    ):
+        # check compiled model can create runtimes.
+        # this logic currently only works in LLM
+        # fail when LLM model using Paged Attention can't guarantee max sequence length
+        pass
+
     @classmethod
     def _from_compiled_models(
         cls,
@@ -295,6 +304,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     ):
         if isinstance(model_save_dir, str):
             model_save_dir = Path(model_save_dir)
+
+        cls._check_compiled_models(compiled_models=rbln_compiled_models, rbln_config=rbln_config, config=config)
+
         # FIXME:: Should we convert it?
         compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
         rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
optimum/rbln/ops/attn.py
CHANGED
@@ -28,7 +28,7 @@ else:
 def register_rbln_custom_paged_attention():
     torch.library.define(
         "rbln_custom_ops::paged_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
@@ -57,28 +57,17 @@ def register_rbln_custom_paged_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_attn_decode")
     def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
@@ -105,23 +94,20 @@ def register_rbln_custom_paged_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
-            - empty_kcache: Same shape as input kcache - Placeholder for compiler
-            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_attn_prefill")
     def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
-        return q
+        return q
 
 
 @lru_cache
 def register_rbln_custom_paged_causal_attention():
     torch.library.define(
         "rbln_custom_ops::paged_causal_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
@@ -149,28 +135,17 @@ def register_rbln_custom_paged_causal_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_causal_attn_decode")
     def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_causal_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
@@ -197,23 +172,20 @@ def register_rbln_custom_paged_causal_attention():
         - block_size: [] - Number of tokens per block
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
-            - empty_kcache: Same shape as input kcache - Placeholder for compiler
-            - empty_vcache: Same shape as input vcache - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
         """
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
     def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
-        return q
+        return q
 
 
 @lru_cache
 def register_rbln_custom_add_softmax_attention():
     torch.library.define(
         "rbln_custom_ops::add_softmax_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
@@ -240,21 +212,10 @@ def register_rbln_custom_add_softmax_attention():
         - scale: [] - Attention scale factor
 
         Returns:
-
-            - attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
-            - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
-            - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+            Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
         """
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::add_softmax_attn_decode")
     def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
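Not part of the diff: a minimal sketch of exercising the CPU fallback of one of the single-output ops registered above. The import path mirrors the file shown here, and the tensor shapes are purely illustrative — the CPU stub simply returns q, so the example only demonstrates the new one-tensor contract.

import torch
from optimum.rbln.ops.attn import register_rbln_custom_paged_attention

register_rbln_custom_paged_attention()

q = torch.randn(1, 8, 4, 1, 64)                    # [batch, n_heads, n_groups, 1, head_dim]
k = v = torch.randn(1, 8, 1, 1, 64)
mask = torch.zeros(1, 1, 1, 1, 128)
kcache = vcache = torch.zeros(16, 8, 1, 128, 64)   # illustrative paged cache
seq = torch.zeros(1, 1, dtype=torch.int32)
scale = torch.tensor(0.125)
block_table = torch.zeros(1, 1, dtype=torch.int16)

out = torch.ops.rbln_custom_ops.paged_attn_decode(
    q, k, v, mask, kcache, vcache, seq, scale, block_table, 128
)
print(out.shape)  # same shape as q: the op now yields a single attention-output tensor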
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -28,71 +28,55 @@ else:
 def register_rbln_custom_paged_flash_attention():
     torch.library.define(
         "rbln_custom_ops::paged_flash_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
     def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_attn_decode")
     def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_flash_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
     def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
     def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
 
 
 @lru_cache
 def register_rbln_custom_paged_flash_causal_attention():
     torch.library.define(
         "rbln_custom_ops::paged_flash_causal_attn_decode",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
     def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
     def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return (
-            q,
-            torch.empty(*kcache.shape, device=kcache.device),
-            torch.empty(*vcache.shape, device=vcache.device),
-        )
+        return q
 
     torch.library.define(
         "rbln_custom_ops::paged_flash_causal_attn_prefill",
-        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor
+        "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
     )
 
     @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
     def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
 
     @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
     def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
-        return q
+        return q
optimum/rbln/ops/kv_cache_update.py
CHANGED
@@ -45,10 +45,10 @@ def register_rbln_custom_cache_update():
 
         # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
         # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
-
+        cache.slice_scatter(state, dim=axis, start=s, end=e)
 
-        #
-        return
+        # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
+        return torch.empty([256])
 
         # Register a "fake" implementation of the "rbln_cache_update" operation.
         # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
@@ -57,4 +57,4 @@ def register_rbln_custom_cache_update():
         # Return a tensor with the same shape as the input cache tensor.
         # This is a placeholder for the abstract implementation and does not perform any actual computation.
         # Like the actual implementation, the abstraction assumes in-place device-side updates.
-        return torch.
+        return torch.empty([256])
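Not part of the diff: a tiny eager-PyTorch sketch of the slice_scatter call this op wraps. In eager mode slice_scatter returns a new tensor with the slice replaced; per the comments above, the compiled RBLN op instead performs the update in place on the device and returns a dummy tensor.

import torch

cache = torch.zeros(1, 4, 8)   # e.g. [batch, heads, max_seq]
state = torch.ones(1, 4, 3)    # new entries for positions 2..4
s, e, axis = 2, 5, 2

updated = cache.slice_scatter(state, dim=axis, start=s, end=e)
print(updated[0, 0])  # tensor([0., 0., 1., 1., 1., 0., 0., 0.])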
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
@@ -281,7 +281,7 @@ class DecoderOnlyWrapper(nn.Module):
             _past_key_values.append(past_key_value)
         past_key_values = _past_key_values
 
-        logit
+        logit = self.causal_lm(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -292,15 +292,7 @@ class DecoderOnlyWrapper(nn.Module):
             block_tables=block_tables,
         )
 
-
-        _present_key_values = ()
-        for i in range(self.num_hidden_layers):
-            key_states = present_key_values[i][0]
-            value_states = present_key_values[i][1]
-            _present_key_values = _present_key_values + (key_states, value_states)
-        present_key_values = _present_key_values
-
-        return logit, present_key_values
+        return logit
 
 
 class DecoderOnlyForCausalLM(nn.Module):
@@ -353,7 +345,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         block_tables: Optional[torch.Tensor] = None,
     ):
         # outputs
-        hidden_states
+        hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -367,8 +359,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self._original_mod.lm_head(hidden_states)
-
-        return output
+        return logits
 
 
 class DecoderOnlyModel(nn.Module):
@@ -484,20 +475,19 @@ class DecoderOnlyModel(nn.Module):
         else:
             seq_positions = cache_position[:, :1]
 
-        present_key_values = past_key_values
         for layer in self.layers:
-            hidden_states
+            hidden_states = layer(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 seq_positions=seq_positions,
-                past_key_values=
+                past_key_values=past_key_values,
                 cos=cos,
                 sin=sin,
                 block_tables=block_tables,
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-        return hidden_states
+        return hidden_states
 
 
 class DecoderOnlyLayer(nn.Module):
@@ -559,7 +549,7 @@ class DecoderOnlyLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
@@ -576,7 +566,7 @@ class DecoderOnlyLayer(nn.Module):
         hidden_states = self._original_mod.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
-        return hidden_states
+        return hidden_states
 
 
 class DecoderOnlyAttention(nn.Module):
@@ -678,7 +668,7 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and self.phase == "prefill":
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
-        attn_output
+        attn_output = self.attention(
             query_states,
             key_states,
             value_states,
@@ -690,12 +680,9 @@ class DecoderOnlyAttention(nn.Module):
             block_tables=block_tables,
             block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state
 
         attn_outputs = self.o_proj(attn_output)
-
-        return attn_outputs, past_key_values
+        return attn_outputs
 
 
 class AttentionOp(nn.Module):
@@ -733,7 +720,7 @@ class AttentionOp(nn.Module):
             scale: Scale applied to attn weights
 
         Returns:
-
+            Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
         """
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
@@ -756,7 +743,7 @@ class AttentionOp(nn.Module):
 
         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -769,7 +756,7 @@ class AttentionOp(nn.Module):
                     block_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -783,7 +770,7 @@ class AttentionOp(nn.Module):
 
         else:
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -796,7 +783,7 @@ class AttentionOp(nn.Module):
                     block_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -812,7 +799,7 @@ class AttentionOp(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
-        return attn_output
+        return attn_output
 
 
 def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
@@ -947,7 +934,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
         if cos is not None and sin is not None:
             query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
 
-        attn_output
+        attn_output = self.attention(
            query_states,
            key_states,
            value_states,
@@ -959,13 +946,9 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
             block_tables=block_tables,
             kvcache_block_size=self.kvcache_block_size,
         )
-        key_states = key_state
-        value_states = value_state
 
         attn_outputs = self.o_proj(attn_output)
-
-
-        return attn_outputs, past_key_values
+        return attn_outputs
 
 
 class FlashAttentionOp(AttentionOp):
@@ -1019,7 +1002,7 @@ class FlashAttentionOp(AttentionOp):
 
         if self.phase == "decode":
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -1033,7 +1016,7 @@ class FlashAttentionOp(AttentionOp):
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
                     query_state,
                     key_state,
                     value_state,
@@ -1047,7 +1030,7 @@ class FlashAttentionOp(AttentionOp):
                 )
         else:
             if self.use_attention_mask:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -1061,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
                     self.kvcache_partition_size,
                 )
             else:
-                attn_output
+                attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
                     query_state,
                     key_state,
                     value_state,
@@ -1079,4 +1062,4 @@ class FlashAttentionOp(AttentionOp):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
-        return attn_output
+        return attn_output
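Not part of the diff: a small shape-check sketch of the grouped layout the AttentionOp comments describe — avoiding repeat_kv by viewing the query heads in KV-head groups and letting broadcasting handle the group dimension. The sizes are hypothetical.

import torch

batch, n_heads, n_kv_heads, seq, head_dim = 1, 32, 8, 1, 128
q = torch.randn(batch, n_heads, seq, head_dim)
k = torch.randn(batch, n_kv_heads, seq, head_dim)

q_grouped = q.view(batch, n_kv_heads, n_heads // n_kv_heads, seq, head_dim)  # [b, kv, groups, s, d]
k_grouped = k.unsqueeze(2)                                                   # [b, kv, 1, s, d]

scores = q_grouped @ k_grouped.transpose(-1, -2)  # broadcasts over the group dim
print(scores.shape)  # torch.Size([1, 8, 4, 1, 1])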
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import math
 from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
@@ -54,7 +55,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor,
         free_block_pool: Deque,
         kvcache_block_size: int,
-        kvcache_num_blocks: int,
         use_attention_mask: bool,
         attn_impl: str,
         **kwargs: Any,
@@ -72,7 +72,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         self.free_block_pool = free_block_pool
 
         self.kvcache_block_size = kvcache_block_size
-        self.empty_block =
+        self.empty_block = -1
         self.attn_impl = attn_impl
 
         if self.phase == "prefill":
@@ -97,58 +97,61 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             torch.Tensor: Updated block tables.
         """
 
-
+        NO_BLOCKS_ERROR = (
+            "No memory blocks are available for allocation."
+            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln."
+            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html)."
+            "Using vllm-rbln should fix this issue and enhance inference performance."
+        )
+
+        def update_block(batch_idx: int, block_idx: int):
             """
-            Helper function to update the block table for a given batch index and block index.
             If the block is empty (empty_block), allocates a block from the free_block_pool.
-
-            Args:
-                batch_idx (int): Batch index.
-                block_idx (int): Block index.
-
-            Raises:
-                RuntimeError: Raised if no available blocks are found in the free_block_pool.
             """
             if self.block_tables[batch_idx][block_idx] == self.empty_block:
                 if self.free_block_pool:
                     block = self.free_block_pool.popleft()
                     self.block_tables[batch_idx][block_idx] = block
                 else:
-                    raise RuntimeError(
+                    raise RuntimeError(NO_BLOCKS_ERROR)
 
-
-
-
+        def replace_empty_block(block_tables: torch.Tensor):
+            """
+            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+            """
+            if not torch.any(block_tables == self.empty_block):
+                return block_tables.clone()
+            elif self.free_block_pool:
+                _free_block = self.free_block_pool[0]
+                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
             else:
-
-
+                raise RuntimeError(NO_BLOCKS_ERROR)
+
+        if self.phase == "prefill":
+            # Track previously used blocks and return them to the free_block_pool and
+            # reset the current batch's block table to empty blocks
+            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+            self.free_block_pool.extend(prev_blocks)
+            self.block_tables[batch_idx].fill_(self.empty_block)
+
+            # Get the start (s) and end (e) positions from cache_position and
+            # iterate over the cache positions to allocate necessary blocks
+            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+            for position in range(s, e + 1, self.kvcache_block_size):
+                block_idx = position // self.kvcache_block_size
+                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                update_block(batch_idx, block_idx)
+
+            return replace_empty_block(self.block_tables[batch_idx])
+        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
         else:
-
-
-
-
-            self.free_block_pool.extend(prev_blocks)
-            self.block_tables[batch_idx].fill_(self.empty_block)
-
-            # Get the start (s) and end (e) positions from cache_position and
-            # iterate over the cache positions to allocate necessary blocks
-            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-            for position in range(s, e + 1, self.kvcache_block_size):
-                block_idx = position // self.kvcache_block_size
-                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                update_block(batch_idx, block_idx)
-
-            return self.block_tables[batch_idx]
-
-        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-        else:
-            for b_idx in range(self.batch_size):
-                position = cache_position[b_idx][0].item()
-                block_idx = position // self.kvcache_block_size
-                update_block(b_idx, block_idx)
+            for b_idx in range(self.batch_size):
+                position = cache_position[b_idx][0].item()
+                block_idx = position // self.kvcache_block_size
+                update_block(b_idx, block_idx)
 
-
+            return replace_empty_block(self.block_tables)
 
     def forward(
         self,
@@ -380,14 +383,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
-
-
-
-
-        block_tables = torch.zeros(
-            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
-        ).fill_(self.kvcache_num_blocks - 1)
-        free_block_pool = deque(x for x in range(self.kvcache_num_blocks - 1))
+        block_tables = torch.zeros(
+            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+        ).fill_(-1)
+        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
 
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
@@ -399,7 +398,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             block_tables=block_tables,
             free_block_pool=free_block_pool,
             kvcache_block_size=self.kvcache_block_size,
-            kvcache_num_blocks=self.kvcache_num_blocks,
             vocab_size=self.config.vocab_size,
             prefill_chunk_size=self.prefill_chunk_size,
             max_seq_len=self.max_seq_len,
@@ -416,7 +414,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             block_tables=block_tables,
             free_block_pool=free_block_pool,
             kvcache_block_size=self.kvcache_block_size,
-            kvcache_num_blocks=self.kvcache_num_blocks,
             use_attention_mask=self.use_attention_mask,
             attn_impl=attn_impl,
         )
@@ -569,6 +566,72 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return compile_model(quantize_config=quantize_config)
 
+    @classmethod
+    def get_maximum_num_blocks(
+        cls,
+        config: PretrainedConfig,
+        tensor_parallel_size: int,
+        kvcache_block_size: int,
+        nbits_per_param: int,
+        n_model_params: int,
+    ) -> int:
+        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
+        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
+        vocab_size = config.vocab_size
+        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
+        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+        TARGET_DRAM_LIMIT = int(tensor_parallel_size * 15.7 * 2**30)  # 16GB  # TODO(jongho): need a more precise value
+
+        def align(x: int, nbytes: int) -> int:
+            return int(math.ceil(x / nbytes) * nbytes)
+
+        def align_2MB(x: int) -> int:
+            return align(x, 2 * 1024 * 1024)
+
+        def get_kernel_size() -> int:
+            # TODO: Implement
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+
+            return layer_nbytes + lm_heads_nbytes
+
+        available_dram = TARGET_DRAM_LIMIT - get_kernel_size()
+
+        buffer = 2**30  # 1GB
+        if tensor_parallel_size <= 2:
+            buffer /= 4
+
+        available_dram -= buffer
+
+        def get_nbytes_per_block() -> int:
+            return (
+                align_2MB(
+                    kvcache_block_size
+                    * head_dim
+                    * math.ceil(num_key_value_heads / tensor_parallel_size)  # Shard
+                    * 2  # (fp16)
+                )
+                * num_layers
+                * 2  # (k, v)
+                * tensor_parallel_size
+            )
+
+        nbytes_per_block = get_nbytes_per_block()
+        n_blocks = available_dram // nbytes_per_block
+
+        return n_blocks, nbytes_per_block
+
     @classmethod
     def _get_rbln_config(
         cls,
@@ -622,8 +685,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             rbln_kvcache_block_size = rbln_kvcache_partition_len
 
-
-
+        max_num_blocks, nbytes_per_block = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
+            kvcache_block_size=rbln_kvcache_block_size,
+            nbits_per_param=16 if rbln_quantization is None else 4,  # TODO(jongho): FIX Ad-hoc
+            n_model_params=rbln_kwargs["n_model_params"],
+        )
+        model_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+        rbln_kvcache_num_blocks = min(model_num_blocks, max_num_blocks)
+
+        required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
+        if rbln_kvcache_num_blocks < required_blocks:
+            rbln_kvcache_num_blocks = required_blocks
+
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+
+        if rbln_kvcache_num_blocks < rbln_batch_size:
+            raise RuntimeError(
+                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                "Ensure the number of blocks is at least equal to the batch size."
+            )
 
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
@@ -723,6 +805,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "kvcache_block_size": rbln_kvcache_block_size,
                 "attn_impl": rbln_attn_impl,
                 "kvcache_num_blocks": rbln_kvcache_num_blocks,
+                "model_num_blocks": model_num_blocks,
+                "max_num_blocks": max_num_blocks,
+                "nbytes_per_block": nbytes_per_block,
             }
         )
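Not part of the diff: a back-of-the-envelope sketch of the new sizing and allocation scheme. All numbers are hypothetical and chosen to keep the arithmetic round; the real values come from get_maximum_num_blocks() and the runtime block tables shown above.

import math
from collections import deque

def align_2MB(x: int) -> int:
    return int(math.ceil(x / (2 * 1024 * 1024)) * (2 * 1024 * 1024))

# Hypothetical model: 32 layers, 8 KV heads, head_dim 128, 1024-token blocks, no tensor parallelism.
num_layers, num_kv_heads, head_dim, block_size, tp = 32, 8, 128, 1024, 1

nbytes_per_block = (
    align_2MB(block_size * head_dim * math.ceil(num_kv_heads / tp) * 2)  # fp16 slice, 2MB-aligned
    * num_layers
    * 2   # K and V
    * tp
)
print(nbytes_per_block // 2**20, "MiB per block")    # 128 MiB for these numbers

available_dram = 8 * 2**30                           # whatever is left after weights and buffer
max_num_blocks = available_dram // nbytes_per_block  # -> 64 blocks

# At runtime the blocks live in one shared pool; -1 marks an unallocated block-table slot.
free_block_pool = deque(range(max_num_blocks))
block_table = [-1] * 4                               # one row, max_seq_len // block_size slots
block_table[0] = free_block_pool.popleft()           # prefill claims slots as positions are reached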
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
CHANGED
@@ -114,11 +114,9 @@ class Seq2SeqEncoderWrapper(nn.Module):
 
         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
         batch_axis = torch.tensor(1, dtype=torch.int16)
-
-            cross_key_values, cross_kv, b_idx[0], batch_axis
-        )
+        enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
 
-        return
+        return enc_out
 
 
 class Seq2SeqDecoderWrapper(nn.Module):
@@ -193,7 +191,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
         # decode
-        lm_logits
+        lm_logits = self.conditional_generation(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
@@ -203,9 +201,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
             block_tables=block_tables,
         )
 
-
-
-        return outputs
+        return lm_logits
 
 
 class Seq2SeqForConditionalGeneration(nn.Module):
@@ -250,7 +246,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
         cache_position,
         block_tables: Optional[torch.Tensor] = None,
     ):
-        hidden_states
+        hidden_states = self.decoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
@@ -265,7 +261,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
 
         lm_logits = self.lm_head(hidden_states)
 
-        return lm_logits
+        return lm_logits
 
 
 class Seq2SeqDecoder(torch.nn.Module):
@@ -326,11 +322,10 @@ class Seq2SeqDecoder(torch.nn.Module):
         hidden_states = self.apply_position_embedding(hidden_states, cache_position)
 
         # iterate decoder_layer
-        self_present_key_values = ()
         for decoder_layer, self_past_key_value, cross_past_key_value in zip(
             self.layers, self_past_key_values, cross_past_key_values
         ):
-            hidden_states
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 encoder_attention_mask=encoder_attention_mask,
@@ -339,12 +334,11 @@ class Seq2SeqDecoder(torch.nn.Module):
                 cache_position=cache_position,
                 block_tables=block_tables,
             )
-            self_present_key_values += self_present_key_value
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
 
-        return hidden_states
+        return hidden_states
 
 
 class Seq2SeqDecoderLayer(torch.nn.Module):
@@ -404,7 +398,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.pre_self_attn_layer_norm(hidden_states)
-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
@@ -429,7 +423,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         # Feed-Forward Block
         hidden_states = self.ff_layer(hidden_states)
 
-        return hidden_states
+        return hidden_states
 
 
 class Seq2SeqSelfAttention(nn.Module):
@@ -492,12 +486,11 @@ class Seq2SeqSelfAttention(nn.Module):
         if attention_mask is not None:
             args.insert(3, attention_mask.unsqueeze(2))
 
-        attn_output
+        attn_output = self.attn_decode(*args)
 
         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
 
         attn_output = self.out_proj(attn_output)
-        present_key_value = (key_states, value_states)
 
-        return attn_output
+        return attn_output
optimum/rbln/transformers/models/t5/t5_architecture.py
CHANGED
@@ -88,7 +88,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
         # decode
-        lm_logits
+        lm_logits = self.conditional_generation(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
@@ -97,9 +97,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
             cache_position=cache_position,
         )
 
-
-
-        return outputs
+        return lm_logits
 
 
 class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
@@ -187,7 +185,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)
 
-        attn_output
+        attn_output = self.attn_decode(
             query_states,
             key_states,
             value_states,
@@ -204,9 +202,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
 
         attn_output = self.out_proj(attn_output)
-
-
-        return attn_output, present_key_value
+        return attn_output
 
 
 class T5CrossAttention(nn.Module):
optimum/rbln/transformers/models/whisper/whisper_architecture.py
CHANGED
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_cache_update
+from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update
 
 
 logger = logging.get_logger(__name__)
@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
 class WhisperWrapper:
     def __init__(self, model, rbln_token_timestamps):
         register_rbln_custom_cache_update()
+        register_rbln_custom_add_softmax_attention()
         self.encoder = WhisperEncoderWrapper(model)
         self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)
 
@@ -77,9 +78,9 @@ class WhisperEncoderWrapper(torch.nn.Module):
         # 3. update cross_attention's past_key_value to the device-dram for optimization.
         bidx = torch.tensor(0, dtype=torch.int16)
         axis = torch.tensor(1, dtype=torch.int16)
-
+        enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
 
-        return
+        return enc_output
 
 
 class WhisperDecoderWrapper(torch.nn.Module):
@@ -118,7 +119,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
         # Decode
-        sequence_output,
+        sequence_output, cross_attentions = self.decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             cache_position=cache_position,
@@ -127,9 +128,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
         )
 
         lm_logits = self.proj_out(sequence_output)
-
         outputs = (lm_logits,)
-        outputs += self_present_key_values
 
         if self.output_attentions:
             # deocder's cross attention is used for token_timestamps
@@ -167,26 +166,23 @@ class WhisperDecoder(nn.Module):
         # prepare casual_attn_mask
         attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
 
-        self_present_key_values = ()
         cross_attentions = ()
         # iterate decoder_layer
         for self_past_key_value, cross_past_key_value, decoder_layer in zip(
             self_past_key_values, cross_past_key_values, self.layers
         ):
-
+            hidden_states, cross_attn_weights = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 self_past_key_value=self_past_key_value,
                 cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
             )
-
-            self_present_key_values += layer_outputs[1]
-            cross_attentions += (layer_outputs[2],)
+            cross_attentions += (cross_attn_weights,)
 
         hidden_states = self.layer_norm(hidden_states)
 
-        return hidden_states,
+        return hidden_states, cross_attentions
 
 
 class WhisperDecoderLayer(nn.Module):
@@ -213,7 +209,7 @@ class WhisperDecoderLayer(nn.Module):
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
@@ -224,7 +220,7 @@ class WhisperDecoderLayer(nn.Module):
         # Cross-Attention Block
         residual = hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
-        hidden_states, cross_attn_weights
+        hidden_states, cross_attn_weights = self.encoder_attn(
             hidden_states=hidden_states,
             past_key_value=cross_past_key_value,
         )
@@ -237,7 +233,7 @@ class WhisperDecoderLayer(nn.Module):
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states
 
-        return hidden_states,
+        return hidden_states, cross_attn_weights
 
 
 class WhisperAttention(nn.Module):
@@ -258,19 +254,8 @@ class WhisperAttention(nn.Module):
 
 
 class WhisperSelfAttention(WhisperAttention):
-    def rbln_cache_update(
-        self,
-        past_key_value: torch.Tensor,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        cache_position: torch.Tensor,
-    ):
-        s_idx = torch.tensor(cache_position, dtype=torch.int16)
-        axis = torch.tensor(2, dtype=torch.int16)
-
-        key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
-        value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
-        return key_states, value_states
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+        return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)
 
     def forward(
         self,
@@ -285,22 +270,27 @@ class WhisperSelfAttention(WhisperAttention):
 
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-        key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)
 
-
-
-
+        attn_output = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(2),
+            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            cache_position.expand(bsz, 1),
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+        )
 
-        attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output
+        return attn_output
 
 
-class WhisperCrossAttention(WhisperSelfAttention):
+class WhisperCrossAttention(WhisperAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -322,4 +312,4 @@ class WhisperCrossAttention(WhisperSelfAttention):
         attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-        return attn_output, attn_weights
+        return attn_output, attn_weights
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.3a2
+Version: 0.7.3a4
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 optimum/rbln/__init__.py,sha256=eHi15YM3989AcX52jka9rUmgAtlp1PHqMNwBEdOfuu8,6554
-optimum/rbln/__version__.py,sha256=
-optimum/rbln/modeling.py,sha256=
-optimum/rbln/modeling_base.py,sha256=
+optimum/rbln/__version__.py,sha256=MLlg_138GxyhciEP0ZB5dPN8vriXkicRnaZiwqygxOY,519
+optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
+optimum/rbln/modeling_base.py,sha256=Ow73GVJF1N5cDFO8_rgirtGj1wC-cXBDyqXHW5PCybA,22270
 optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
 optimum/rbln/diffusers/__init__.py,sha256=pOyoXv3-JRzTBSwPKbgLS9H6F2K9dJdReEmpGhcLQYU,3283
 optimum/rbln/diffusers/modeling_diffusers.py,sha256=zqVNgH9oeOx2iNE7VsW_FinVf4s6G5Idyh4TKz7XJJg,21116
@@ -40,9 +40,9 @@
 optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
 optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
 optimum/rbln/ops/__init__.py,sha256=TxOmsN0u3PmyK4Sb89qbiC4rePOlkvUT7Lm6wVoTnY0,941
-optimum/rbln/ops/attn.py,sha256=
-optimum/rbln/ops/flash_attn.py,sha256=
-optimum/rbln/ops/kv_cache_update.py,sha256=
+optimum/rbln/ops/attn.py,sha256=3EqU63Oj4zI4rLbkRycorsscXeD-IpKzt9N1MhkMa5o,10374
+optimum/rbln/ops/flash_attn.py,sha256=wfyiCxDGf034IngzwRU160R7_DlKYpd-uWT0BDEGFks,3408
+optimum/rbln/ops/kv_cache_update.py,sha256=pxf8kAptPaQF5xE8qItvmlFOq_sgim6ZERD7AVaOtec,3221
 optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
 optimum/rbln/transformers/modeling_alias.py,sha256=yx7FnZQWAnrWzivaO5hI7T6i-fyLzt2tMIXG2oDNbPo,1657
 optimum/rbln/transformers/modeling_generic.py,sha256=aaZWsqVDCRvH03q-Wen7DMfLr7Gy-u-I0mTw0aYqWjk,18195
@@ -59,8 +59,8 @@
 optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
-optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=
-optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=
+optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=7OIKteJLKNxOLOg0w3lLOM7TxZovQn4jkglI9wRkrtQ,40609
+optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=W9HnxJoTz78Wc4X5Q3sMSHhMTSa7-9uQCFlnqNVozvA,38932
 optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
 optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
 optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -91,16 +91,16 @@
 optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
 optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
 optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
-optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=
+optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
 optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
 optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
-optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=
+optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=AArCQhZRETVM583wlIRzMFOSYq7t2nzxaAeyhZxyxKk,9508
 optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
 optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
 optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
 optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
 optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
-optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=
+optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=Yn6yFpmw6IQbWlnpIMAdEUsNF6huXgaKzGMUZbhSLdo,12572
 optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
 optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
 optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -114,7 +114,7 @@
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.7.3a2.dist-info/METADATA,sha256=
-optimum_rbln-0.7.3a2.dist-info/WHEEL,sha256=
-optimum_rbln-0.7.3a2.dist-info/licenses/LICENSE,sha256=
-optimum_rbln-0.7.3a2.dist-info/RECORD,,
+optimum_rbln-0.7.3a4.dist-info/METADATA,sha256=8VNTOVgsgFtcFUuZ9VEeRQfC2LEB60OFmW92hlJo8V8,5300
+optimum_rbln-0.7.3a4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.7.3a4.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.7.3a4.dist-info/RECORD,,
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/WHEEL
File without changes
{optimum_rbln-0.7.3a2.dist-info → optimum_rbln-0.7.3a4.dist-info}/licenses/LICENSE
File without changes