optimum-rbln 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/modeling_base.py +3 -3
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +49 -17
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +759 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +126 -32
- optimum/rbln/transformers/models/midm/__init__.py +32 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +426 -0
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.4.dist-info}/METADATA +5 -4
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.4.dist-info}/RECORD +19 -11
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.4.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -51,6 +51,7 @@ _import_structure = {
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNMidmLMHeadModel",
         "RBLNWhisperForConditionalGeneration",
     ],
     "diffusers": [
@@ -107,6 +108,7 @@ if TYPE_CHECKING:
         RBLNCLIPTextModelWithProjection,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNMidmLMHeadModel,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
     )
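Both hunks register the new Midm head model: the string entry extends the package's lazy _import_structure, and the mirrored import under TYPE_CHECKING keeps static analysis in sync, so the class resolves from the package root exactly like the existing Llama and GPT-2 entries. A quick sanity check of the export path, assuming the 0.1.4 wheel is installed (this snippet is illustrative and not part of the package):

    from optimum.rbln import RBLNMidmLMHeadModel
    from optimum.rbln.transformers import RBLNMidmLMHeadModel as FromSubpackage

    # Both import paths lazily resolve to the same class, defined in
    # optimum/rbln/transformers/models/midm/modeling_midm.py.
    assert RBLNMidmLMHeadModel is FromSubpackage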
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.1.1'
+__version__ = '0.1.4'
optimum/rbln/modeling_base.py
CHANGED
@@ -99,7 +99,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
 
     model_type = "rbln_model"
    auto_model_class = AutoModel  # feature extraction
-    config_name = "
+    config_name = "config.json"
 
     def __init__(
         self,
@@ -490,7 +490,7 @@ class RBLNModel(RBLNBaseModel):
         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
 
         # Get compilation arguments
-        if rbln_config_kwargs.
+        if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
             rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
 
         rbln_runtime_configs = list(rbln_config.values())
@@ -595,7 +595,7 @@ class RBLNModelForImageClassification(RBLNModel):
             rbln_image_size = processor.size["shortest_edge"]
             break
         if rbln_image_size is None:
-            raise ValueError("`
+            raise ValueError("`rbln_image_size` should be specified!")
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
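The functional change in the second hunk (@@ -490,7 +490,7) is the walrus expression: an explicitly passed rbln_config is popped out of the keyword arguments so it is not forwarded twice, and a default config is built from the preprocessors and model config only when the caller did not supply one. A minimal sketch of that pop-or-build pattern (the function name and keys below are illustrative, not optimum-rbln APIs):

    def resolve_config(kwargs: dict) -> dict:
        # Pop an explicit "rbln_config" so the remaining kwargs can be forwarded
        # cleanly; build a default only when the caller did not pass one.
        if (config := kwargs.pop("rbln_config", None)) is None:
            config = {"source": "defaults", **kwargs}
        return config

    print(resolve_config({"rbln_batch_size": 2}))                 # built from defaults
    print(resolve_config({"rbln_config": {"source": "caller"}}))  # caller-supplied config wins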
optimum/rbln/transformers/__init__.py
CHANGED
@@ -35,6 +35,7 @@ _import_structure = {
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
+        "RBLNMidmLMHeadModel",
     ],
 }
 
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
         RBLNCLIPTextModelWithProjection,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNMidmLMHeadModel,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
     )
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -24,5 +24,6 @@
 from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
+from .midm import RBLNMidmLMHeadModel
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
optimum/rbln/transformers/models/llama/llama_architecture.py
CHANGED
@@ -36,7 +36,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaModel,
     LlamaRotaryEmbedding,
-    repeat_kv,
 )
 
 
@@ -149,26 +148,41 @@ class _LlamaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
+        # change to remove repeat
+        key_states = key_states.unsqueeze(2)
+        value_states = value_states.unsqueeze(2)
+        query_states = query_states.view(
+            bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
+        )
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # change to remove repeat
+        # key_states = repeat_kv(key_states, self.num_key_value_groups)
+        # value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+        # change to remove repeat
+        # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        #     raise ValueError(
+        #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+        #         f" {attn_weights.size()}"
+        #     )
 
         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+            else:
+                # change to remove repeat
+                attention_mask = attention_mask.unsqueeze(2)
+
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
@@ -176,6 +190,9 @@ class _LlamaAttention(LlamaAttention):
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
+        # change to remove repeat
+        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
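Taken together, the two attention hunks replace the repeat_kv tiling with a broadcasted grouped-query formulation: queries are viewed as (bsz, num_key_value_heads, num_heads // num_key_value_heads, q_len, head_dim), keys and values get a singleton group dimension from unsqueeze(2), the score matmul therefore uses transpose(3, 4) on 5-D tensors, and the final view folds the groups back into num_heads. This avoids materializing repeated copies of each K/V head, which is what the recurring "# change to remove repeat" comments refer to. A standalone sketch (tensor sizes are arbitrary, not taken from the diff) checking that the broadcast path matches the repeated-KV reference:

    import torch

    bsz, n_heads, n_kv, q_len, kv_len, d = 2, 8, 2, 4, 6, 16
    group = n_heads // n_kv

    q = torch.randn(bsz, n_heads, q_len, d)
    k = torch.randn(bsz, n_kv, kv_len, d)
    v = torch.randn(bsz, n_kv, kv_len, d)

    # Reference path: tile each KV head `group` times, as transformers' repeat_kv does.
    k_rep = k.repeat_interleave(group, dim=1)
    v_rep = v.repeat_interleave(group, dim=1)
    ref = torch.softmax(q @ k_rep.transpose(2, 3) / d**0.5, dim=-1) @ v_rep

    # Broadcast path from the diff: 5-D queries against unsqueezed K/V.
    q5 = q.view(bsz, n_kv, group, q_len, d)
    scores = q5 @ k.unsqueeze(2).transpose(3, 4) / d**0.5        # (bsz, n_kv, group, q_len, kv_len)
    out = (torch.softmax(scores, dim=-1) @ v.unsqueeze(2)).view(bsz, n_heads, q_len, d)

    print(torch.allclose(ref, out, atol=1e-5))  # True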
@@ -516,17 +533,32 @@ class RebelDynamicCache(DynamicCache):
         if len(self.key_cache) <= layer_idx:
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
         else:
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
-                key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+            # change to remove repeat
+            # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+            #     key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+            # )
+            # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+            #     value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+            # )
+            updated_key = (
+                self.key_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
+                )
             )
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
-                value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+            updated_value = (
+                self.value_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
+                )
             )
-
-
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+            self.key_cache[layer_idx] = updated_key.squeeze(2)
+            self.value_cache[layer_idx] = updated_value.squeeze(2)
+            return updated_key, updated_value
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""