optimum-rbln 0.7.3a2__py3-none-any.whl → 0.7.3a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.7.3a2'
- __version_tuple__ = version_tuple = (0, 7, 3)
+ __version__ = version = '0.7.3a4'
+ __version_tuple__ = version_tuple = (0, 7, 3, 'a4')
optimum/rbln/modeling.py CHANGED
@@ -134,6 +134,9 @@ class RBLNModel(RBLNBaseModel):
  for preprocessor in preprocessors:
  preprocessor.save_pretrained(save_dir_path / subfolder)

+ # ad-hoc
+ rbln_kwargs["n_model_params"] = sum(p.numel() for p in model.parameters())
+
  # Get compilation arguments (e.g. input_info)
  rbln_config: RBLNConfig = cls.get_rbln_config(
  preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
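
The `n_model_params` value recorded above is simply the total number of parameter elements in the PyTorch model. A minimal sketch of what that expression computes, using a hypothetical toy module (not part of the package):

    import torch

    toy = torch.nn.Linear(4, 3)  # weight is 3x4 (12 elements) plus a bias of 3
    n_model_params = sum(p.numel() for p in toy.parameters())
    print(n_model_params)  # 15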
@@ -282,6 +282,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  **kwargs,
  )

+ @classmethod
+ def _check_compiled_models(
+ cls, compiled_models: Dict[str, rebel.RBLNCompiledModel], rbln_config: RBLNConfig, config: "PretrainedConfig"
+ ):
+ # check compiled model can create runtimes.
+ # this logic currently only works in LLM
+ # fail when LLM model using Paged Attention can't guarantee max sequence length
+ pass
+
  @classmethod
  def _from_compiled_models(
  cls,
@@ -295,6 +304,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  ):
  if isinstance(model_save_dir, str):
  model_save_dir = Path(model_save_dir)
+
+ cls._check_compiled_models(compiled_models=rbln_compiled_models, rbln_config=rbln_config, config=config)
+
  # FIXME:: Should we convert it?
  compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
  rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
optimum/rbln/ops/attn.py CHANGED
@@ -28,7 +28,7 @@ else:
  def register_rbln_custom_paged_attention():
  torch.library.define(
  "rbln_custom_ops::paged_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
@@ -57,28 +57,17 @@ def register_rbln_custom_paged_attention():
  - block_size: [] - Number of tokens per block

  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
- - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
- - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
  """
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  @register_fake("rbln_custom_ops::paged_attn_decode")
  def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  torch.library.define(
  "rbln_custom_ops::paged_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
@@ -105,23 +94,20 @@ def register_rbln_custom_paged_attention():
  - block_size: [] - Number of tokens per block

  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
- - empty_kcache: Same shape as input kcache - Placeholder for compiler
- - empty_vcache: Same shape as input vcache - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
  """
- return q, kcache, vcache
+ return q

  @register_fake("rbln_custom_ops::paged_attn_prefill")
  def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
- return q, kcache, vcache
+ return q


  @lru_cache
  def register_rbln_custom_paged_causal_attention():
  torch.library.define(
  "rbln_custom_ops::paged_causal_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
@@ -149,28 +135,17 @@ def register_rbln_custom_paged_causal_attention():
  - block_size: [] - Number of tokens per block

  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
- - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
- - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
  """
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  @register_fake("rbln_custom_ops::paged_causal_attn_decode")
  def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  torch.library.define(
  "rbln_custom_ops::paged_causal_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
@@ -197,23 +172,20 @@ def register_rbln_custom_paged_causal_attention():
  - block_size: [] - Number of tokens per block

  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
- - empty_kcache: Same shape as input kcache - Placeholder for compiler
- - empty_vcache: Same shape as input vcache - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
  """
- return q, kcache, vcache
+ return q

  @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
  def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- return q, kcache, vcache
+ return q


  @lru_cache
  def register_rbln_custom_add_softmax_attention():
  torch.library.define(
  "rbln_custom_ops::add_softmax_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
@@ -240,21 +212,10 @@ def register_rbln_custom_add_softmax_attention():
  - scale: [] - Attention scale factor

  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
- - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
- - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
  """
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  @register_fake("rbln_custom_ops::add_softmax_attn_decode")
  def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
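
All of the attn.py schemas above change from "-> Tensor[]" to "-> Tensor", and the CPU and fake implementations now return only q instead of a (q, kcache, vcache) tuple. A self-contained sketch of that single-output registration pattern, assuming PyTorch 2.4+ for torch.library.register_fake and using a hypothetical demo_ops namespace rather than the real rbln_custom_ops:

    import torch
    from torch.library import register_fake

    torch.library.define(
        "demo_ops::attn_decode",
        "(Tensor q, Tensor k, Tensor v, Tensor kcache, Tensor vcache) -> Tensor",
    )

    @torch.library.impl("demo_ops::attn_decode", "cpu")
    def attn_decode_cpu(q, k, v, kcache, vcache):
        # Placeholder CPU body, mirroring the package's stubs; the real kernel is supplied by the compiler.
        return q.clone()

    @register_fake("demo_ops::attn_decode")
    def attn_decode_fake(q, k, v, kcache, vcache):
        # The fake (meta) implementation only propagates shape/dtype for tracing; no cache placeholders are returned.
        return torch.empty_like(q)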
@@ -28,71 +28,55 @@ else:
  def register_rbln_custom_paged_flash_attention():
  torch.library.define(
  "rbln_custom_ops::paged_flash_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
  def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  @register_fake("rbln_custom_ops::paged_flash_attn_decode")
  def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  torch.library.define(
  "rbln_custom_ops::paged_flash_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
  def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q

  @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
  def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q


  @lru_cache
  def register_rbln_custom_paged_flash_causal_attention():
  torch.library.define(
  "rbln_custom_ops::paged_flash_causal_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
  def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
  def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q

  torch.library.define(
  "rbln_custom_ops::paged_flash_causal_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )

  @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
  def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q

  @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
  def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q
@@ -45,10 +45,10 @@ def register_rbln_custom_cache_update():

  # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
  # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
- updated_cache = cache.slice_scatter(state, dim=axis, start=s, end=e)
+ cache.slice_scatter(state, dim=axis, start=s, end=e)

- # Return the updated cache tensor.
- return updated_cache
+ # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
+ return torch.empty([256])

  # Register a "fake" implementation of the "rbln_cache_update" operation.
  # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
@@ -57,4 +57,4 @@ def register_rbln_custom_cache_update():
  # Return a tensor with the same shape as the input cache tensor.
  # This is a placeholder for the abstract implementation and does not perform any actual computation.
  # Like the actual implementation, the abstraction assumes in-place device-side updates.
- return torch.empty_like(cache)
+ return torch.empty([256])
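
For context, Tensor.slice_scatter in plain PyTorch is out-of-place: it returns a new tensor with the state written into the selected slice and leaves the input unchanged. A small, standalone illustration of that behavior (independent of the RBLN op, which the comments above document as in-place on device):

    import torch

    cache = torch.zeros(2, 8)
    state = torch.ones(2, 3)

    updated = cache.slice_scatter(state, dim=1, start=2, end=5)
    print(updated[0])  # tensor([0., 0., 1., 1., 1., 0., 0., 0.])
    print(cache[0])    # original cache is untouched: all zeros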
@@ -281,7 +281,7 @@ class DecoderOnlyWrapper(nn.Module):
  _past_key_values.append(past_key_value)
  past_key_values = _past_key_values

- logit, present_key_values = self.causal_lm(
+ logit = self.causal_lm(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -292,15 +292,7 @@ class DecoderOnlyWrapper(nn.Module):
  block_tables=block_tables,
  )

- # ((key, value)) * n_layer -> [key, value] * n_layer
- _present_key_values = ()
- for i in range(self.num_hidden_layers):
- key_states = present_key_values[i][0]
- value_states = present_key_values[i][1]
- _present_key_values = _present_key_values + (key_states, value_states)
- present_key_values = _present_key_values
-
- return logit, present_key_values
+ return logit


  class DecoderOnlyForCausalLM(nn.Module):
@@ -353,7 +345,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  block_tables: Optional[torch.Tensor] = None,
  ):
  # outputs
- hidden_states, present_key_values = self.model(
+ hidden_states = self.model(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -367,8 +359,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

  logits = self._original_mod.lm_head(hidden_states)
- output = (logits, present_key_values)
- return output
+ return logits


  class DecoderOnlyModel(nn.Module):
@@ -484,20 +475,19 @@ class DecoderOnlyModel(nn.Module):
  else:
  seq_positions = cache_position[:, :1]

- present_key_values = past_key_values
  for layer in self.layers:
- hidden_states, present_key_values = layer(
+ hidden_states = layer(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
  seq_positions=seq_positions,
- past_key_values=present_key_values,
+ past_key_values=past_key_values,
  cos=cos,
  sin=sin,
  block_tables=block_tables,
  )

  hidden_states = self.get_last_layernorm()(hidden_states)
- return hidden_states, present_key_values
+ return hidden_states


  class DecoderOnlyLayer(nn.Module):
@@ -559,7 +549,7 @@ class DecoderOnlyLayer(nn.Module):
  residual = hidden_states
  hidden_states = self.get_pre_attention_layernorm()(hidden_states)

- hidden_states, present_key_values = self.self_attn(
+ hidden_states = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
  seq_positions=seq_positions,
@@ -576,7 +566,7 @@ class DecoderOnlyLayer(nn.Module):
  hidden_states = self._original_mod.mlp(hidden_states)
  hidden_states = residual + hidden_states

- return hidden_states, present_key_values
+ return hidden_states


  class DecoderOnlyAttention(nn.Module):
@@ -678,7 +668,7 @@ class DecoderOnlyAttention(nn.Module):
  if batch_size > 1 and self.phase == "prefill":
  raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

- attn_output, key_state, value_state = self.attention(
+ attn_output = self.attention(
  query_states,
  key_states,
  value_states,
@@ -690,12 +680,9 @@ class DecoderOnlyAttention(nn.Module):
  block_tables=block_tables,
  block_size=self.kvcache_block_size,
  )
- key_states = key_state
- value_states = value_state

  attn_outputs = self.o_proj(attn_output)
- past_key_values[self.layer_idx] = key_states, value_states
- return attn_outputs, past_key_values
+ return attn_outputs


  class AttentionOp(nn.Module):
@@ -733,7 +720,7 @@ class AttentionOp(nn.Module):
  scale: Scale applied to attn weights

  Returns:
- Tuple of (attention_output, key_state, value_state)
+ Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
  """
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
@@ -756,7 +743,7 @@ class AttentionOp(nn.Module):

  if self.phase == "decode":
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -769,7 +756,7 @@ class AttentionOp(nn.Module):
  block_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -783,7 +770,7 @@ class AttentionOp(nn.Module):

  else:
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -796,7 +783,7 @@ class AttentionOp(nn.Module):
  block_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -812,7 +799,7 @@ class AttentionOp(nn.Module):
  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

- return attn_output, key_state.squeeze(2), value_state.squeeze(2)
+ return attn_output


  def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
@@ -947,7 +934,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
  if cos is not None and sin is not None:
  query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)

- attn_output, key_state, value_state = self.attention(
+ attn_output = self.attention(
  query_states,
  key_states,
  value_states,
@@ -959,13 +946,9 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
  block_tables=block_tables,
  kvcache_block_size=self.kvcache_block_size,
  )
- key_states = key_state
- value_states = value_state

  attn_outputs = self.o_proj(attn_output)
- past_key_values[self.layer_idx] = key_states, value_states
-
- return attn_outputs, past_key_values
+ return attn_outputs


  class FlashAttentionOp(AttentionOp):
@@ -1019,7 +1002,7 @@ class FlashAttentionOp(AttentionOp):

  if self.phase == "decode":
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -1033,7 +1016,7 @@ class FlashAttentionOp(AttentionOp):
  self.kvcache_partition_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -1047,7 +1030,7 @@ class FlashAttentionOp(AttentionOp):
  )
  else:
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -1061,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
  self.kvcache_partition_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -1079,4 +1062,4 @@ class FlashAttentionOp(AttentionOp):
  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

- return attn_output, key_state, value_state
+ return attn_output
@@ -13,6 +13,7 @@
  # limitations under the License.

  import inspect
+ import math
  from collections import deque
  from dataclasses import dataclass
  from pathlib import Path
@@ -54,7 +55,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  block_tables: torch.Tensor,
  free_block_pool: Deque,
  kvcache_block_size: int,
- kvcache_num_blocks: int,
  use_attention_mask: bool,
  attn_impl: str,
  **kwargs: Any,
@@ -72,7 +72,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  self.free_block_pool = free_block_pool

  self.kvcache_block_size = kvcache_block_size
- self.empty_block = kvcache_num_blocks - 1
+ self.empty_block = -1
  self.attn_impl = attn_impl

  if self.phase == "prefill":
@@ -97,58 +97,61 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  torch.Tensor: Updated block tables.
  """

- def update_block(batch_idx, block_idx):
+ NO_BLOCKS_ERROR = (
+ "No memory blocks are available for allocation."
+ "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln."
+ "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html)."
+ "Using vllm-rbln should fix this issue and enhance inference performance."
+ )
+
+ def update_block(batch_idx: int, block_idx: int):
  """
- Helper function to update the block table for a given batch index and block index.
  If the block is empty (empty_block), allocates a block from the free_block_pool.
-
- Args:
- batch_idx (int): Batch index.
- block_idx (int): Block index.
-
- Raises:
- RuntimeError: Raised if no available blocks are found in the free_block_pool.
  """
  if self.block_tables[batch_idx][block_idx] == self.empty_block:
  if self.free_block_pool:
  block = self.free_block_pool.popleft()
  self.block_tables[batch_idx][block_idx] = block
  else:
- raise RuntimeError("Not available blocks")
+ raise RuntimeError(NO_BLOCKS_ERROR)

- if self.attn_impl == "eager":
- if self.phase == "prefill":
- return self.block_tables[batch_idx]
+ def replace_empty_block(block_tables: torch.Tensor):
+ """
+ Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+ """
+ if not torch.any(block_tables == self.empty_block):
+ return block_tables.clone()
+ elif self.free_block_pool:
+ _free_block = self.free_block_pool[0]
+ return torch.where(block_tables == self.empty_block, _free_block, block_tables)
  else:
- return self.block_tables
- # Case for 'flash_attn' attention implementation
+ raise RuntimeError(NO_BLOCKS_ERROR)
+
+ if self.phase == "prefill":
+ # Track previously used blocks and return them to the free_block_pool and
+ # reset the current batch's block table to empty blocks
+ prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+ self.free_block_pool.extend(prev_blocks)
+ self.block_tables[batch_idx].fill_(self.empty_block)
+
+ # Get the start (s) and end (e) positions from cache_position and
+ # iterate over the cache positions to allocate necessary blocks
+ s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+ for position in range(s, e + 1, self.kvcache_block_size):
+ block_idx = position // self.kvcache_block_size
+ if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+ raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+ update_block(batch_idx, block_idx)
+
+ return replace_empty_block(self.block_tables[batch_idx])
+ # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
  else:
- if self.phase == "prefill":
- # Track previously used blocks and return them to the free_block_pool and
- # reset the current batch's block table to empty blocks
- prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
- self.free_block_pool.extend(prev_blocks)
- self.block_tables[batch_idx].fill_(self.empty_block)
-
- # Get the start (s) and end (e) positions from cache_position and
- # iterate over the cache positions to allocate necessary blocks
- s, e = cache_position[0][0].item(), cache_position[0][-1].item()
- for position in range(s, e + 1, self.kvcache_block_size):
- block_idx = position // self.kvcache_block_size
- if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
- raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
- update_block(batch_idx, block_idx)
-
- return self.block_tables[batch_idx]
-
- # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
- else:
- for b_idx in range(self.batch_size):
- position = cache_position[b_idx][0].item()
- block_idx = position // self.kvcache_block_size
- update_block(b_idx, block_idx)
+ for b_idx in range(self.batch_size):
+ position = cache_position[b_idx][0].item()
+ block_idx = position // self.kvcache_block_size
+ update_block(b_idx, block_idx)

- return self.block_tables
+ return replace_empty_block(self.block_tables)

  def forward(
  self,
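
The block-table bookkeeping above (empty slots marked with -1, a deque as the free pool, one block consumed per kvcache_block_size positions) reduces to a short pattern. A hypothetical, self-contained sketch of that allocation logic, with made-up sizes:

    from collections import deque

    import torch

    EMPTY = -1
    batch_size, blocks_per_seq, num_blocks, block_size = 2, 4, 6, 128
    block_tables = torch.full((batch_size, blocks_per_seq), EMPTY, dtype=torch.int16)
    free_blocks = deque(range(num_blocks))

    def allocate(batch_idx: int, position: int) -> None:
        # Map a cache position to its block slot and pull from the free pool if the slot is empty.
        block_idx = position // block_size
        if block_tables[batch_idx, block_idx] == EMPTY:
            if not free_blocks:
                raise RuntimeError("No memory blocks are available for allocation.")
            block_tables[batch_idx, block_idx] = free_blocks.popleft()

    allocate(0, 0)    # first block of sequence 0
    allocate(0, 130)  # a second block once the position crosses the 128-token boundary
    print(block_tables[0])  # tensor([ 0,  1, -1, -1], dtype=torch.int16)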
@@ -380,14 +383,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

  # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
  dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
- if attn_impl == "eager":
- block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).reshape(self.batch_size, 1)
- free_block_pool = None
- else:
- block_tables = torch.zeros(
- self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
- ).fill_(self.kvcache_num_blocks - 1)
- free_block_pool = deque(x for x in range(self.kvcache_num_blocks - 1))
+ block_tables = torch.zeros(
+ self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+ ).fill_(-1)
+ free_block_pool = deque(x for x in range(self.kvcache_num_blocks))

  self.prefill_decoder = RBLNRuntimeModel(
  runtime=self.model[0],
@@ -399,7 +398,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  block_tables=block_tables,
  free_block_pool=free_block_pool,
  kvcache_block_size=self.kvcache_block_size,
- kvcache_num_blocks=self.kvcache_num_blocks,
  vocab_size=self.config.vocab_size,
  prefill_chunk_size=self.prefill_chunk_size,
  max_seq_len=self.max_seq_len,
@@ -416,7 +414,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  block_tables=block_tables,
  free_block_pool=free_block_pool,
  kvcache_block_size=self.kvcache_block_size,
- kvcache_num_blocks=self.kvcache_num_blocks,
  use_attention_mask=self.use_attention_mask,
  attn_impl=attn_impl,
  )
@@ -569,6 +566,72 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

  return compile_model(quantize_config=quantize_config)

+ @classmethod
+ def get_maximum_num_blocks(
+ cls,
+ config: PretrainedConfig,
+ tensor_parallel_size: int,
+ kvcache_block_size: int,
+ nbits_per_param: int,
+ n_model_params: int,
+ ) -> int:
+ num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
+ num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
+ vocab_size = config.vocab_size
+ hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
+ num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+ TARGET_DRAM_LIMIT = int(tensor_parallel_size * 15.7 * 2**30) # 16GB # TODO(jongho): need a more accurate value
+
+ def align(x: int, nbytes: int) -> int:
+ return int(math.ceil(x / nbytes) * nbytes)
+
+ def align_2MB(x: int) -> int:
+ return align(x, 2 * 1024 * 1024)
+
+ def get_kernel_size() -> int:
+ # TODO: Implement
+ lm_heads_params = align(vocab_size, 64) * hidden_size
+ lm_heads_nbytes = (
+ align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+ )
+
+ params = n_model_params - lm_heads_params
+ layer_nbytes = (
+ align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+ * num_layers
+ * tensor_parallel_size
+ )
+
+ return layer_nbytes + lm_heads_nbytes
+
+ available_dram = TARGET_DRAM_LIMIT - get_kernel_size()
+
+ buffer = 2**30 # 1GB
+ if tensor_parallel_size <= 2:
+ buffer /= 4
+
+ available_dram -= buffer
+
+ def get_nbytes_per_block() -> int:
+ return (
+ align_2MB(
+ kvcache_block_size
+ * head_dim
+ * math.ceil(num_key_value_heads / tensor_parallel_size) # Shard
+ * 2 # (fp16)
+ )
+ * num_layers
+ * 2 # (k, v)
+ * tensor_parallel_size
+ )
+
+ nbytes_per_block = get_nbytes_per_block()
+ n_blocks = available_dram // nbytes_per_block
+
+ return n_blocks, nbytes_per_block
+
  @classmethod
  def _get_rbln_config(
  cls,
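
A back-of-the-envelope run of the estimate above, with hypothetical 7B-class numbers (every value below is illustrative, not taken from any real config): subtract the 2 MiB-aligned weight footprint and a 1 GiB buffer from the per-device DRAM target, then divide by the aligned size of one KV-cache block.

    import math

    def align(x, n):
        return math.ceil(x / n) * n

    def align_2mb(x):
        return align(x, 2 * 1024 * 1024)

    tp, block_size, nbits = 4, 128, 16
    layers, kv_heads, head_dim, hidden, vocab = 32, 32, 128, 4096, 32000
    n_params = 6_700_000_000  # illustrative total parameter count

    dram_limit = int(tp * 15.7 * 2**30)
    lm_head_params = align(vocab, 64) * hidden
    kernel = align_2mb(lm_head_params * nbits // 8 / tp) * tp
    kernel += align_2mb((n_params - lm_head_params) * nbits // 8 / layers / tp) * layers * tp

    available = dram_limit - kernel - 2**30  # 1 GiB buffer since tp > 2
    per_block = align_2mb(block_size * head_dim * math.ceil(kv_heads / tp) * 2) * layers * 2 * tp
    print(available // per_block)  # approximate number of allocatable KV-cache blocks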
@@ -622,8 +685,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  else:
  rbln_kvcache_block_size = rbln_kvcache_partition_len

- # FIXME temporal num_blocks
- rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+ max_num_blocks, nbytes_per_block = cls.get_maximum_num_blocks(
+ config=model_config,
+ tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
+ kvcache_block_size=rbln_kvcache_block_size,
+ nbits_per_param=16 if rbln_quantization is None else 4, # TODO(jongho): FIX Ad-hoc
+ n_model_params=rbln_kwargs["n_model_params"],
+ )
+ model_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+ rbln_kvcache_num_blocks = min(model_num_blocks, max_num_blocks)
+
+ required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
+ if rbln_kvcache_num_blocks < required_blocks:
+ rbln_kvcache_num_blocks = required_blocks
+
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+
+ if rbln_kvcache_num_blocks < rbln_batch_size:
+ raise RuntimeError(
+ f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+ "Ensure the number of blocks is at least equal to the batch size."
+ )

  num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
  num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
@@ -723,6 +805,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  "kvcache_block_size": rbln_kvcache_block_size,
  "attn_impl": rbln_attn_impl,
  "kvcache_num_blocks": rbln_kvcache_num_blocks,
+ "model_num_blocks": model_num_blocks,
+ "max_num_blocks": max_num_blocks,
+ "nbytes_per_block": nbytes_per_block,
  }
  )

@@ -114,11 +114,9 @@ class Seq2SeqEncoderWrapper(nn.Module):

  # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
  batch_axis = torch.tensor(1, dtype=torch.int16)
- cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
- cross_key_values, cross_kv, b_idx[0], batch_axis
- )
+ enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)

- return cross_key_values
+ return enc_out


  class Seq2SeqDecoderWrapper(nn.Module):
@@ -193,7 +191,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
  cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)

  # decode
- lm_logits, self_present_key_values = self.conditional_generation(
+ lm_logits = self.conditional_generation(
  input_ids=input_ids,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -203,9 +201,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
  block_tables=block_tables,
  )

- outputs = (lm_logits,) + self_present_key_values
-
- return outputs
+ return lm_logits


  class Seq2SeqForConditionalGeneration(nn.Module):
@@ -250,7 +246,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
  cache_position,
  block_tables: Optional[torch.Tensor] = None,
  ):
- hidden_states, self_present_key_values = self.decoder(
+ hidden_states = self.decoder(
  input_ids=input_ids,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -265,7 +261,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):

  lm_logits = self.lm_head(hidden_states)

- return lm_logits, self_present_key_values
+ return lm_logits


  class Seq2SeqDecoder(torch.nn.Module):
@@ -326,11 +322,10 @@ class Seq2SeqDecoder(torch.nn.Module):
  hidden_states = self.apply_position_embedding(hidden_states, cache_position)

  # iterate decoder_layer
- self_present_key_values = ()
  for decoder_layer, self_past_key_value, cross_past_key_value in zip(
  self.layers, self_past_key_values, cross_past_key_values
  ):
- hidden_states, self_present_key_value = decoder_layer(
+ hidden_states = decoder_layer(
  hidden_states,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -339,12 +334,11 @@ class Seq2SeqDecoder(torch.nn.Module):
  cache_position=cache_position,
  block_tables=block_tables,
  )
- self_present_key_values += self_present_key_value

  if self.final_layer_norm is not None:
  hidden_states = self.final_layer_norm(hidden_states)

- return hidden_states, self_present_key_values
+ return hidden_states


  class Seq2SeqDecoderLayer(torch.nn.Module):
@@ -404,7 +398,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
  # Self Attention Block
  residual = hidden_states
  hidden_states = self.pre_self_attn_layer_norm(hidden_states)
- hidden_states, self_attn_past_key_value = self.self_attn(
+ hidden_states = self.self_attn(
  hidden_states=hidden_states,
  past_key_value=self_past_key_value,
  attention_mask=attention_mask,
@@ -429,7 +423,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
  # Feed-Forward Block
  hidden_states = self.ff_layer(hidden_states)

- return hidden_states, self_attn_past_key_value
+ return hidden_states


  class Seq2SeqSelfAttention(nn.Module):
@@ -492,12 +486,11 @@ class Seq2SeqSelfAttention(nn.Module):
  if attention_mask is not None:
  args.insert(3, attention_mask.unsqueeze(2))

- attn_output, key_states, value_states = self.attn_decode(*args)
+ attn_output = self.attn_decode(*args)

  attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
  attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)

  attn_output = self.out_proj(attn_output)
- present_key_value = (key_states, value_states)

- return attn_output, present_key_value
+ return attn_output
@@ -88,7 +88,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
  cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)

  # decode
- lm_logits, self_present_key_values = self.conditional_generation(
+ lm_logits = self.conditional_generation(
  input_ids=input_ids,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -97,9 +97,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
  cache_position=cache_position,
  )

- outputs = (lm_logits,) + self_present_key_values
-
- return outputs
+ return lm_logits


  class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
@@ -187,7 +185,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
  key_states = self._shape(key_states, -1, bsz)
  value_states = self._shape(value_states, -1, bsz)

- attn_output, key_states, value_states = self.attn_decode(
+ attn_output = self.attn_decode(
  query_states,
  key_states,
  value_states,
@@ -204,9 +202,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
  attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)

  attn_output = self.out_proj(attn_output)
- present_key_value = (key_states, value_states)
-
- return attn_output, present_key_value
+ return attn_output


  class T5CrossAttention(nn.Module):
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import (
  )
  from transformers.utils import logging

- from ....ops import register_rbln_custom_cache_update
+ from ....ops import register_rbln_custom_add_softmax_attention, register_rbln_custom_cache_update


  logger = logging.get_logger(__name__)
@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
  class WhisperWrapper:
  def __init__(self, model, rbln_token_timestamps):
  register_rbln_custom_cache_update()
+ register_rbln_custom_add_softmax_attention()
  self.encoder = WhisperEncoderWrapper(model)
  self.decoder = WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps)

@@ -77,9 +78,9 @@ class WhisperEncoderWrapper(torch.nn.Module):
  # 3. update cross_attention's past_key_value to the device-dram for optimization.
  bidx = torch.tensor(0, dtype=torch.int16)
  axis = torch.tensor(1, dtype=torch.int16)
- cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+ enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)

- return cross_key_values
+ return enc_output


  class WhisperDecoderWrapper(torch.nn.Module):
@@ -118,7 +119,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
  cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)

  # Decode
- sequence_output, self_present_key_values, cross_attentions = self.decoder(
+ sequence_output, cross_attentions = self.decoder(
  input_ids=decoder_input_ids,
  attention_mask=decoder_attention_mask,
  cache_position=cache_position,
@@ -127,9 +128,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
  )

  lm_logits = self.proj_out(sequence_output)
-
  outputs = (lm_logits,)
- outputs += self_present_key_values

  if self.output_attentions:
  # deocder's cross attention is used for token_timestamps
@@ -167,26 +166,23 @@ class WhisperDecoder(nn.Module):
  # prepare casual_attn_mask
  attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)

- self_present_key_values = ()
  cross_attentions = ()
  # iterate decoder_layer
  for self_past_key_value, cross_past_key_value, decoder_layer in zip(
  self_past_key_values, cross_past_key_values, self.layers
  ):
- layer_outputs = decoder_layer(
+ hidden_states, cross_attn_weights = decoder_layer(
  hidden_states,
  attention_mask=attention_mask,
  self_past_key_value=self_past_key_value,
  cross_past_key_value=cross_past_key_value,
  cache_position=cache_position,
  )
- hidden_states = layer_outputs[0]
- self_present_key_values += layer_outputs[1]
- cross_attentions += (layer_outputs[2],)
+ cross_attentions += (cross_attn_weights,)

  hidden_states = self.layer_norm(hidden_states)

- return hidden_states, self_present_key_values, cross_attentions
+ return hidden_states, cross_attentions


  class WhisperDecoderLayer(nn.Module):
@@ -213,7 +209,7 @@ class WhisperDecoderLayer(nn.Module):
  # Self Attention Block
  residual = hidden_states
  hidden_states = self.self_attn_layer_norm(hidden_states)
- hidden_states, _, self_present_key_value = self.self_attn(
+ hidden_states = self.self_attn(
  hidden_states=hidden_states,
  past_key_value=self_past_key_value,
  attention_mask=attention_mask,
@@ -224,7 +220,7 @@ class WhisperDecoderLayer(nn.Module):
  # Cross-Attention Block
  residual = hidden_states
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
- hidden_states, cross_attn_weights, cross_present_key_value = self.encoder_attn(
+ hidden_states, cross_attn_weights = self.encoder_attn(
  hidden_states=hidden_states,
  past_key_value=cross_past_key_value,
  )
@@ -237,7 +233,7 @@ class WhisperDecoderLayer(nn.Module):
  hidden_states = self.fc2(hidden_states)
  hidden_states = residual + hidden_states

- return hidden_states, self_present_key_value, cross_attn_weights
+ return hidden_states, cross_attn_weights


  class WhisperAttention(nn.Module):
@@ -258,19 +254,8 @@ class WhisperAttention(nn.Module):


  class WhisperSelfAttention(WhisperAttention):
- def rbln_cache_update(
- self,
- past_key_value: torch.Tensor,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- cache_position: torch.Tensor,
- ):
- s_idx = torch.tensor(cache_position, dtype=torch.int16)
- axis = torch.tensor(2, dtype=torch.int16)
-
- key_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[0], key_states, s_idx, axis)
- value_states = torch.ops.rbln_custom_ops.rbln_cache_update(past_key_value[1], value_states, s_idx, axis)
- return key_states, value_states
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+ return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)

  def forward(
  self,
@@ -285,22 +270,27 @@ class WhisperSelfAttention(WhisperAttention):

  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states, value_states = self.rbln_cache_update(past_key_value, key_states, value_states, cache_position)

- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
- attn_weights = attn_weights + attention_mask
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_output = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask.unsqueeze(2),
+ past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+ past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+ cache_position.expand(bsz, 1),
+ torch.tensor(1.0, dtype=torch.float32), # scale
+ )

- attn_output = torch.matmul(attn_weights, value_states)
  attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  attn_output = attn_output.transpose(1, 2)
  attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  attn_output = self.out_proj(attn_output)

- return attn_output, attn_weights, (key_states, value_states)
+ return attn_output


- class WhisperCrossAttention(WhisperSelfAttention):
+ class WhisperCrossAttention(WhisperAttention):
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -322,4 +312,4 @@ class WhisperCrossAttention(WhisperSelfAttention):
  attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
  attn_output = self.out_proj(attn_output)

- return attn_output, attn_weights, (key_states, value_states)
+ return attn_output, attn_weights
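
The fused add_softmax_attn_decode call above replaces the explicit matmul, mask add, and softmax that the previous Whisper self-attention performed. As a reference only, a plain-PyTorch sketch of the math being fused (not the RBLN kernel, and without the cache handling):

    import torch

    def add_softmax_attention(q, k, v, mask, scale=1.0):
        # scores = q @ k^T * scale + mask, softmax over the key axis, then a weighted sum of values.
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale + mask
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, v)

    q = torch.randn(1, 8, 1, 64)      # (batch, heads, query_len, head_dim)
    k = torch.randn(1, 8, 32, 64)
    v = torch.randn(1, 8, 32, 64)
    mask = torch.zeros(1, 1, 1, 32)
    print(add_softmax_attention(q, k, v, mask).shape)  # torch.Size([1, 8, 1, 64])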
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.7.3a2
+ Version: 0.7.3a4
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai
@@ -1,7 +1,7 @@
  optimum/rbln/__init__.py,sha256=eHi15YM3989AcX52jka9rUmgAtlp1PHqMNwBEdOfuu8,6554
- optimum/rbln/__version__.py,sha256=bShBukYvw7AqWtLsut0yClygDEGsFRmxrXypqIeEXcQ,513
- optimum/rbln/modeling.py,sha256=3XE0IrCYbkjw9_Q9BFzZ_ri_Kyxw1g6iwfdohZB46-s,8289
- optimum/rbln/modeling_base.py,sha256=ELSPbjx7awBRM2SckkD-5gI3TIa01mfzz7gDRC1Pljk,21778
+ optimum/rbln/__version__.py,sha256=MLlg_138GxyhciEP0ZB5dPN8vriXkicRnaZiwqygxOY,519
+ optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
+ optimum/rbln/modeling_base.py,sha256=Ow73GVJF1N5cDFO8_rgirtGj1wC-cXBDyqXHW5PCybA,22270
  optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
  optimum/rbln/diffusers/__init__.py,sha256=pOyoXv3-JRzTBSwPKbgLS9H6F2K9dJdReEmpGhcLQYU,3283
  optimum/rbln/diffusers/modeling_diffusers.py,sha256=zqVNgH9oeOx2iNE7VsW_FinVf4s6G5Idyh4TKz7XJJg,21116
@@ -40,9 +40,9 @@ optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_x
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
  optimum/rbln/ops/__init__.py,sha256=TxOmsN0u3PmyK4Sb89qbiC4rePOlkvUT7Lm6wVoTnY0,941
- optimum/rbln/ops/attn.py,sha256=LbJAmFtNj05i6BURfKV3KybsPItFe8w-YdSe5SuWkEc,12365
- optimum/rbln/ops/flash_attn.py,sha256=4shKNY13skPoYnbEsGrXDzgNwBIhHZEFrnUnWx1ESZU,4076
- optimum/rbln/ops/kv_cache_update.py,sha256=9W4WCO1Dtfy0u5i978JJRa7uLbqrfR2lHuoPynb07fw,3143
+ optimum/rbln/ops/attn.py,sha256=3EqU63Oj4zI4rLbkRycorsscXeD-IpKzt9N1MhkMa5o,10374
+ optimum/rbln/ops/flash_attn.py,sha256=wfyiCxDGf034IngzwRU160R7_DlKYpd-uWT0BDEGFks,3408
+ optimum/rbln/ops/kv_cache_update.py,sha256=pxf8kAptPaQF5xE8qItvmlFOq_sgim6ZERD7AVaOtec,3221
  optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
  optimum/rbln/transformers/modeling_alias.py,sha256=yx7FnZQWAnrWzivaO5hI7T6i-fyLzt2tMIXG2oDNbPo,1657
  optimum/rbln/transformers/modeling_generic.py,sha256=aaZWsqVDCRvH03q-Wen7DMfLr7Gy-u-I0mTw0aYqWjk,18195
@@ -59,8 +59,8 @@ optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaM
  optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
  optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
  optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=x8_xQ5aGXbadJyajpJQyfgxx4YPSj62VlmmGDMnC-1E,41819
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=dyl8tDBjfe5VfU1XbKAoZS7g7F90JTYVmMuz0HTmCoE,35345
+ optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=7OIKteJLKNxOLOg0w3lLOM7TxZovQn4jkglI9wRkrtQ,40609
+ optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=W9HnxJoTz78Wc4X5Q3sMSHhMTSa7-9uQCFlnqNVozvA,38932
  optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
  optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
  optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -91,16 +91,16 @@ optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz
  optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
  optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
  optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=QXIGWSu9PsKWE3WhkgmBj3VeszqXIo2MPOwcrb54Tbs,19348
+ optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
  optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
  optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
- optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=oCdmF4eCTayAVjx3c-SVpmhrjnWE92jh79dMIYCwotY,9690
+ optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=AArCQhZRETVM583wlIRzMFOSYq7t2nzxaAeyhZxyxKk,9508
  optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
  optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
  optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
  optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
  optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
- optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=eP3UgkwCRaaFjc5Jc4ZEiWxr3-L7oJx9KzpJ7eFkwUs,13158
+ optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=Yn6yFpmw6IQbWlnpIMAdEUsNF6huXgaKzGMUZbhSLdo,12572
  optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
  optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
  optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -114,7 +114,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
  optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
  optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
- optimum_rbln-0.7.3a2.dist-info/METADATA,sha256=C-IWumO-veJFZPHpF8wcOTOE0TCDzKU1Xk_ylaqrvPM,5300
- optimum_rbln-0.7.3a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- optimum_rbln-0.7.3a2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
- optimum_rbln-0.7.3a2.dist-info/RECORD,,
+ optimum_rbln-0.7.3a4.dist-info/METADATA,sha256=8VNTOVgsgFtcFUuZ9VEeRQfC2LEB60OFmW92hlJo8V8,5300
+ optimum_rbln-0.7.3a4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ optimum_rbln-0.7.3a4.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ optimum_rbln-0.7.3a4.dist-info/RECORD,,