optimum-rbln 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
optimum/rbln/__init__.py CHANGED
@@ -51,6 +51,7 @@ _import_structure = {
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNMidmLMHeadModel",
         "RBLNWhisperForConditionalGeneration",
     ],
     "diffusers": [
@@ -107,6 +108,7 @@ if TYPE_CHECKING:
         RBLNCLIPTextModelWithProjection,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNMidmLMHeadModel,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
     )
@@ -1 +1 @@
-__version__ = '0.1.1'
+__version__ = '0.1.4'
@@ -99,7 +99,7 @@ class RBLNBaseModel(OptimizedModel, ABC):

     model_type = "rbln_model"
     auto_model_class = AutoModel  # feature extraction
-    config_name = "model_index.json"
+    config_name = "config.json"

     def __init__(
         self,
@@ -490,7 +490,7 @@ class RBLNModel(RBLNBaseModel):
         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

         # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
+        if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
             rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)

         rbln_runtime_configs = list(rbln_config.values())
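
Note on the hunk above: with the old `.get(...)`, a caller-supplied `rbln_config` was never bound to the local name (the assignment only ran on the fallback path) and stayed inside `rbln_config_kwargs`; the walrus-plus-`pop` form binds it and removes it so it is not forwarded again. A minimal standalone sketch of the pattern, not the library's API, with illustrative names:

    def resolve_config(kwargs: dict) -> dict:
        # Pop the key so it is not forwarded twice; fall back only when it is absent.
        if (cfg := kwargs.pop("rbln_config", None)) is None:
            cfg = {"derived_from": dict(kwargs)}  # stand-in for cls.get_rbln_config(...)
        return cfg

    print(resolve_config({"rbln_batch_size": 1}))           # -> derived default
    print(resolve_config({"rbln_config": {"user": True}}))  # -> caller-supplied config
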
@@ -595,7 +595,7 @@ class RBLNModelForImageClassification(RBLNModel):
                     rbln_image_size = processor.size["shortest_edge"]
                     break
             if rbln_image_size is None:
-                raise ValueError("`rbln_rbln_image_size` should be specified!")
+                raise ValueError("`rbln_image_size` should be specified!")

         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -35,6 +35,7 @@ _import_structure = {
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
+        "RBLNMidmLMHeadModel",
     ],
 }

@@ -45,6 +46,7 @@ if TYPE_CHECKING:
         RBLNCLIPTextModelWithProjection,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNMidmLMHeadModel,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
     )
@@ -24,5 +24,6 @@
 from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
+from .midm import RBLNMidmLMHeadModel
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
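
The `__init__` hunks above register the new `RBLNMidmLMHeadModel` so it resolves from the package root. A hypothetical usage sketch, assuming the class follows the same `from_pretrained(..., export=True)` convention as the other RBLN* causal-LM classes; the checkpoint id and the `rbln_*` options below are illustrative and not taken from this diff:

    from optimum.rbln import RBLNMidmLMHeadModel

    model = RBLNMidmLMHeadModel.from_pretrained(
        "some-org/midm-checkpoint",  # hypothetical checkpoint id
        export=True,                 # compile the checkpoint for the RBLN runtime (assumed convention)
        rbln_batch_size=1,           # illustrative compile-time options
        rbln_max_seq_len=4096,
    )
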
@@ -36,7 +36,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaModel,
     LlamaRotaryEmbedding,
-    repeat_kv,
 )

@@ -149,26 +148,41 @@ class _LlamaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+        # change to remove repeat
+        key_states = key_states.unsqueeze(2)
+        value_states = value_states.unsqueeze(2)
+        query_states = query_states.view(
+            bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
+        )
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # change to remove repeat
+        # key_states = repeat_kv(key_states, self.num_key_value_groups)
+        # value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+        # change to remove repeat
+        # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        #     raise ValueError(
+        #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+        #         f" {attn_weights.size()}"
+        #     )

         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+            else:
+                # change to remove repeat
+                attention_mask = attention_mask.unsqueeze(2)
+
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
@@ -176,6 +190,9 @@ class _LlamaAttention(LlamaAttention):
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)

+        # change to remove repeat
+        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
@@ -516,17 +533,32 @@ class RebelDynamicCache(DynamicCache):
         if len(self.key_cache) <= layer_idx:
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
         else:
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
-                key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+            # change to remove repeat
+            # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+            #     key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+            # )
+            # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+            #     value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+            # )
+            updated_key = (
+                self.key_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
+                )
             )
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
-                value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+            updated_value = (
+                self.value_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
+                )
             )
-            # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+            self.key_cache[layer_idx] = updated_key.squeeze(2)
+            self.value_cache[layer_idx] = updated_value.squeeze(2)
+            return updated_key, updated_value

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""