optimum-rbln 0.7.3a3__py3-none-any.whl → 0.7.3a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
optimum/rbln/__version__.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.7.3a3'
- __version_tuple__ = version_tuple = (0, 7, 3, 'a3')
+ __version__ = version = '0.7.3a4'
+ __version_tuple__ = version_tuple = (0, 7, 3, 'a4')
optimum/rbln/modeling.py CHANGED
@@ -134,6 +134,9 @@ class RBLNModel(RBLNBaseModel):
  for preprocessor in preprocessors:
  preprocessor.save_pretrained(save_dir_path / subfolder)
 
+ # ad-hoc
+ rbln_kwargs["n_model_params"] = sum(p.numel() for p in model.parameters())
+
  # Get compilation arguments (e.g. input_info)
  rbln_config: RBLNConfig = cls.get_rbln_config(
  preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
optimum/rbln/modeling_base.py CHANGED
@@ -282,6 +282,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  **kwargs,
  )
 
+ @classmethod
+ def _check_compiled_models(
+ cls, compiled_models: Dict[str, rebel.RBLNCompiledModel], rbln_config: RBLNConfig, config: "PretrainedConfig"
+ ):
+ # check compiled model can create runtimes.
+ # this logic currently only works in LLM
+ # fail when LLM model using Paged Attention can't guarantee max sequence length
+ pass
+
  @classmethod
  def _from_compiled_models(
  cls,
@@ -295,6 +304,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  ):
  if isinstance(model_save_dir, str):
  model_save_dir = Path(model_save_dir)
+
+ cls._check_compiled_models(compiled_models=rbln_compiled_models, rbln_config=rbln_config, config=config)
+
  # FIXME:: Should we convert it?
  compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
  rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
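The new `_check_compiled_models` hook is deliberately a no-op at the base-class level; per its comments, an LLM subclass is expected to override it so paged-attention models fail at load time when the compiled KV cache cannot cover the maximum sequence length. Purely as an illustration of the kind of check such an override might perform (the `model_cfg` keys below are assumptions for this sketch, not the library's documented config layout):

```python
# Hypothetical sketch only: the shape of a check an LLM subclass might run from
# its _check_compiled_models() override. All config keys here are assumed.
def check_paged_attention_coverage(rbln_config) -> None:
    model_cfg = getattr(rbln_config, "model_cfg", None) or {}
    num_blocks = model_cfg.get("kvcache_num_blocks")
    block_size = model_cfg.get("kvcache_block_size")
    max_seq_len = model_cfg.get("max_seq_len")
    if None in (num_blocks, block_size, max_seq_len):
        return  # not a paged-attention LLM config; nothing to validate
    if num_blocks * block_size < max_seq_len:
        raise ValueError(
            f"Compiled KV cache ({num_blocks} blocks x {block_size} tokens) "
            f"cannot guarantee max_seq_len={max_seq_len}."
        )
```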
optimum/rbln/ops/attn.py CHANGED
@@ -28,7 +28,7 @@ else:
  def register_rbln_custom_paged_attention():
  torch.library.define(
  "rbln_custom_ops::paged_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::paged_attn_decode", "cpu")
@@ -57,28 +57,17 @@ def register_rbln_custom_paged_attention():
  - block_size: [] - Number of tokens per block
 
  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
- - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
- - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
  """
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  @register_fake("rbln_custom_ops::paged_attn_decode")
  def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  torch.library.define(
  "rbln_custom_ops::paged_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::paged_attn_prefill", "cpu")
@@ -105,23 +94,20 @@ def register_rbln_custom_paged_attention():
  - block_size: [] - Number of tokens per block
 
  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
- - empty_kcache: Same shape as input kcache - Placeholder for compiler
- - empty_vcache: Same shape as input vcache - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
  """
- return q, kcache, vcache
+ return q
 
  @register_fake("rbln_custom_ops::paged_attn_prefill")
  def attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size):
- return q, kcache, vcache
+ return q
 
 
  @lru_cache
  def register_rbln_custom_paged_causal_attention():
  torch.library.define(
  "rbln_custom_ops::paged_causal_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::paged_causal_attn_decode", "cpu")
@@ -149,28 +135,17 @@ def register_rbln_custom_paged_causal_attention():
  - block_size: [] - Number of tokens per block
 
  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
- - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
- - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
  """
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  @register_fake("rbln_custom_ops::paged_causal_attn_decode")
  def attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  torch.library.define(
  "rbln_custom_ops::paged_causal_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::paged_causal_attn_prefill", "cpu")
@@ -197,23 +172,20 @@ def register_rbln_custom_paged_causal_attention():
  - block_size: [] - Number of tokens per block
 
  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
- - empty_kcache: Same shape as input kcache - Placeholder for compiler
- - empty_vcache: Same shape as input vcache - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
  """
- return q, kcache, vcache
+ return q
 
  @register_fake("rbln_custom_ops::paged_causal_attn_prefill")
  def attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size):
- return q, kcache, vcache
+ return q
 
 
  @lru_cache
  def register_rbln_custom_add_softmax_attention():
  torch.library.define(
  "rbln_custom_ops::add_softmax_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
@@ -240,21 +212,10 @@ def register_rbln_custom_add_softmax_attention():
  - scale: [] - Attention scale factor
 
  Returns:
- Tuple[Tensor, Tensor, Tensor]:
- - attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
- - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
- - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+ Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
  """
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  @register_fake("rbln_custom_ops::add_softmax_attn_decode")
  def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
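All of the schemas above change from `-> Tensor[]` to `-> Tensor`, so each custom attention op is now declared, implemented on CPU, and given a fake (meta) implementation as a single-output op. A minimal, self-contained sketch of that registration pattern (the toy op name and shapes are invented for illustration and are not part of optimum-rbln):

```python
import torch
from torch.library import register_fake  # recent PyTorch; older releases used torch.library.impl_abstract

# Declare a toy single-output op, mirroring the "... -> Tensor" schema style above.
torch.library.define("toy_ops::single_out_attn", "(Tensor q, Tensor k, Tensor v) -> Tensor")

@torch.library.impl("toy_ops::single_out_attn", "cpu")
def single_out_attn_cpu(q, k, v):
    # Reference CPU body: plain scaled-dot-product-style attention, single output.
    scores = torch.softmax(q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5, dim=-1)
    return scores @ v

@register_fake("toy_ops::single_out_attn")
def single_out_attn_fake(q, k, v):
    # Shape/dtype-only implementation used during tracing and compilation.
    return torch.empty_like(q)

q = k = v = torch.randn(1, 8, 4, 64)
out = torch.ops.toy_ops.single_out_attn(q, k, v)
print(out.shape)  # torch.Size([1, 8, 4, 64])
```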
optimum/rbln/ops/flash_attn.py CHANGED
@@ -28,71 +28,55 @@ else:
  def register_rbln_custom_paged_flash_attention():
  torch.library.define(
  "rbln_custom_ops::paged_flash_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
  def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  @register_fake("rbln_custom_ops::paged_flash_attn_decode")
  def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  torch.library.define(
  "rbln_custom_ops::paged_flash_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
  def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q
 
  @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
  def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q
 
 
  @lru_cache
  def register_rbln_custom_paged_flash_causal_attention():
  torch.library.define(
  "rbln_custom_ops::paged_flash_causal_attn_decode",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
  def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
  def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return (
- q,
- torch.empty(*kcache.shape, device=kcache.device),
- torch.empty(*vcache.shape, device=vcache.device),
- )
+ return q
 
  torch.library.define(
  "rbln_custom_ops::paged_flash_causal_attn_prefill",
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor[]",
+ "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
  )
 
  @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
  def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q
 
  @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
  def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
- return q, kcache, vcache
+ return q
optimum/rbln/ops/kv_cache_update.py CHANGED
@@ -45,10 +45,10 @@ def register_rbln_custom_cache_update():
 
  # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
  # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
- updated_cache = cache.slice_scatter(state, dim=axis, start=s, end=e)
+ cache.slice_scatter(state, dim=axis, start=s, end=e)
 
- # Return the updated cache tensor.
- return updated_cache
+ # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
+ return torch.empty([256])
 
  # Register a "fake" implementation of the "rbln_cache_update" operation.
  # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
@@ -57,4 +57,4 @@ def register_rbln_custom_cache_update():
  # Return a tensor with the same shape as the input cache tensor.
  # This is a placeholder for the abstract implementation and does not perform any actual computation.
  # Like the actual implementation, the abstraction assumes in-place device-side updates.
- return torch.empty_like(cache)
+ return torch.empty([256])
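For reference, `slice_scatter` is a standard PyTorch tensor method and, in eager mode, returns a new tensor rather than mutating its input; the in-place behaviour described in the comments is how the compiled RBLN op treats the cache on device, which is why the traced graph now returns a small dummy tensor instead of the cache itself. A tiny illustration of the `slice_scatter` semantics (shapes invented for the example, not the real KV-cache layout):

```python
import torch

cache = torch.zeros(2, 8, 4)   # e.g. [batch, seq, dim] -- made-up shapes
state = torch.ones(2, 3, 4)    # new entries for positions 2..4 along dim=1

# Writes `state` into the slice cache[:, 2:5, :] and returns the result.
updated = cache.slice_scatter(state, dim=1, start=2, end=5)
print(updated[0, :, 0])        # tensor([0., 0., 1., 1., 1., 0., 0., 0.])
```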
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py CHANGED
@@ -281,7 +281,7 @@ class DecoderOnlyWrapper(nn.Module):
  _past_key_values.append(past_key_value)
  past_key_values = _past_key_values
 
- logit, present_key_values = self.causal_lm(
+ logit = self.causal_lm(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -292,15 +292,7 @@ class DecoderOnlyWrapper(nn.Module):
  block_tables=block_tables,
  )
 
- # ((key, value)) * n_layer -> [key, value] * n_layer
- _present_key_values = ()
- for i in range(self.num_hidden_layers):
- key_states = present_key_values[i][0]
- value_states = present_key_values[i][1]
- _present_key_values = _present_key_values + (key_states, value_states)
- present_key_values = _present_key_values
-
- return logit, present_key_values
+ return logit
 
 
  class DecoderOnlyForCausalLM(nn.Module):
@@ -353,7 +345,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  block_tables: Optional[torch.Tensor] = None,
  ):
  # outputs
- hidden_states, present_key_values = self.model(
+ hidden_states = self.model(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -367,8 +359,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
  logits = self._original_mod.lm_head(hidden_states)
- output = (logits, present_key_values)
- return output
+ return logits
 
 
  class DecoderOnlyModel(nn.Module):
@@ -484,20 +475,19 @@ class DecoderOnlyModel(nn.Module):
  else:
  seq_positions = cache_position[:, :1]
 
- present_key_values = past_key_values
  for layer in self.layers:
- hidden_states, present_key_values = layer(
+ hidden_states = layer(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
  seq_positions=seq_positions,
- past_key_values=present_key_values,
+ past_key_values=past_key_values,
  cos=cos,
  sin=sin,
  block_tables=block_tables,
  )
 
  hidden_states = self.get_last_layernorm()(hidden_states)
- return hidden_states, present_key_values
+ return hidden_states
 
 
  class DecoderOnlyLayer(nn.Module):
@@ -559,7 +549,7 @@ class DecoderOnlyLayer(nn.Module):
  residual = hidden_states
  hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
- hidden_states, present_key_values = self.self_attn(
+ hidden_states = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
  seq_positions=seq_positions,
@@ -576,7 +566,7 @@ class DecoderOnlyLayer(nn.Module):
  hidden_states = self._original_mod.mlp(hidden_states)
  hidden_states = residual + hidden_states
 
- return hidden_states, present_key_values
+ return hidden_states
 
 
  class DecoderOnlyAttention(nn.Module):
@@ -678,7 +668,7 @@ class DecoderOnlyAttention(nn.Module):
  if batch_size > 1 and self.phase == "prefill":
  raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
- attn_output, key_state, value_state = self.attention(
+ attn_output = self.attention(
  query_states,
  key_states,
  value_states,
@@ -690,12 +680,9 @@ class DecoderOnlyAttention(nn.Module):
  block_tables=block_tables,
  block_size=self.kvcache_block_size,
  )
- key_states = key_state
- value_states = value_state
 
  attn_outputs = self.o_proj(attn_output)
- past_key_values[self.layer_idx] = key_states, value_states
- return attn_outputs, past_key_values
+ return attn_outputs
 
 
  class AttentionOp(nn.Module):
@@ -733,7 +720,7 @@ class AttentionOp(nn.Module):
  scale: Scale applied to attn weights
 
  Returns:
- Tuple of (attention_output, key_state, value_state)
+ Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
  """
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
@@ -756,7 +743,7 @@ class AttentionOp(nn.Module):
 
  if self.phase == "decode":
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -769,7 +756,7 @@ class AttentionOp(nn.Module):
  block_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -783,7 +770,7 @@ class AttentionOp(nn.Module):
 
  else:
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -796,7 +783,7 @@ class AttentionOp(nn.Module):
  block_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -812,7 +799,7 @@ class AttentionOp(nn.Module):
  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
- return attn_output, key_state.squeeze(2), value_state.squeeze(2)
+ return attn_output
 
 
  def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
@@ -947,7 +934,7 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
  if cos is not None and sin is not None:
  query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
 
- attn_output, key_state, value_state = self.attention(
+ attn_output = self.attention(
  query_states,
  key_states,
  value_states,
@@ -959,13 +946,9 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
  block_tables=block_tables,
  kvcache_block_size=self.kvcache_block_size,
  )
- key_states = key_state
- value_states = value_state
 
  attn_outputs = self.o_proj(attn_output)
- past_key_values[self.layer_idx] = key_states, value_states
-
- return attn_outputs, past_key_values
+ return attn_outputs
 
 
  class FlashAttentionOp(AttentionOp):
@@ -1019,7 +1002,7 @@ class FlashAttentionOp(AttentionOp):
 
  if self.phase == "decode":
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -1033,7 +1016,7 @@ class FlashAttentionOp(AttentionOp):
  self.kvcache_partition_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
  query_state,
  key_state,
  value_state,
@@ -1047,7 +1030,7 @@ class FlashAttentionOp(AttentionOp):
  )
  else:
  if self.use_attention_mask:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -1061,7 +1044,7 @@ class FlashAttentionOp(AttentionOp):
  self.kvcache_partition_size,
  )
  else:
- attn_output, key_state, value_state = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
+ attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
  query_state,
  key_state,
  value_state,
@@ -1079,4 +1062,4 @@ class FlashAttentionOp(AttentionOp):
  attn_output = attn_output.transpose(1, 2).contiguous()
  attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
- return attn_output, key_state, value_state
+ return attn_output
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py CHANGED
@@ -13,6 +13,7 @@
  # limitations under the License.
 
  import inspect
+ import math
  from collections import deque
  from dataclasses import dataclass
  from pathlib import Path
@@ -565,6 +566,72 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
  return compile_model(quantize_config=quantize_config)
 
+ @classmethod
+ def get_maximum_num_blocks(
+ cls,
+ config: PretrainedConfig,
+ tensor_parallel_size: int,
+ kvcache_block_size: int,
+ nbits_per_param: int,
+ n_model_params: int,
+ ) -> int:
+ num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
+ num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
+ vocab_size = config.vocab_size
+ hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
+ num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+ TARGET_DRAM_LIMIT = int(tensor_parallel_size * 15.7 * 2**30) # 16GB # TODO(jongho): use a more accurate value
+
+ def align(x: int, nbytes: int) -> int:
+ return int(math.ceil(x / nbytes) * nbytes)
+
+ def align_2MB(x: int) -> int:
+ return align(x, 2 * 1024 * 1024)
+
+ def get_kernel_size() -> int:
+ # TODO: Implement
+ lm_heads_params = align(vocab_size, 64) * hidden_size
+ lm_heads_nbytes = (
+ align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+ )
+
+ params = n_model_params - lm_heads_params
+ layer_nbytes = (
+ align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+ * num_layers
+ * tensor_parallel_size
+ )
+
+ return layer_nbytes + lm_heads_nbytes
+
+ available_dram = TARGET_DRAM_LIMIT - get_kernel_size()
+
+ buffer = 2**30 # 1GB
+ if tensor_parallel_size <= 2:
+ buffer /= 4
+
+ available_dram -= buffer
+
+ def get_nbytes_per_block() -> int:
+ return (
+ align_2MB(
+ kvcache_block_size
+ * head_dim
+ * math.ceil(num_key_value_heads / tensor_parallel_size) # Shard
+ * 2 # (fp16)
+ )
+ * num_layers
+ * 2 # (k, v)
+ * tensor_parallel_size
+ )
+
+ nbytes_per_block = get_nbytes_per_block()
+ n_blocks = available_dram // nbytes_per_block
+
+ return n_blocks, nbytes_per_block
+
  @classmethod
  def _get_rbln_config(
  cls,
@@ -618,8 +685,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  else:
  rbln_kvcache_block_size = rbln_kvcache_partition_len
 
- # FIXME temporal num_blocks
- rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+ max_num_blocks, nbytes_per_block = cls.get_maximum_num_blocks(
+ config=model_config,
+ tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
+ kvcache_block_size=rbln_kvcache_block_size,
+ nbits_per_param=16 if rbln_quantization is None else 4, # TODO(jongho): FIX Ad-hoc
+ n_model_params=rbln_kwargs["n_model_params"],
+ )
+ model_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+ rbln_kvcache_num_blocks = min(model_num_blocks, max_num_blocks)
+
+ required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
+ if rbln_kvcache_num_blocks < required_blocks:
+ rbln_kvcache_num_blocks = required_blocks
+
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+
+ if rbln_kvcache_num_blocks < rbln_batch_size:
+ raise RuntimeError(
+ f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+ "Ensure the number of blocks is at least equal to the batch size."
+ )
 
  num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
  num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
@@ -719,6 +805,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  "kvcache_block_size": rbln_kvcache_block_size,
  "attn_impl": rbln_attn_impl,
  "kvcache_num_blocks": rbln_kvcache_num_blocks,
+ "model_num_blocks": model_num_blocks,
+ "max_num_blocks": max_num_blocks,
+ "nbytes_per_block": nbytes_per_block,
  }
  )
 
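To make the new block budgeting concrete, the following is a rough, self-contained re-run of the same arithmetic with made-up numbers for a hypothetical ~7B-parameter fp16 model on 4 devices (every value below is an assumption chosen for illustration, not a default of the library):

```python
import math

# Hypothetical deployment, chosen only to exercise the formula in get_maximum_num_blocks().
tp = 4                           # tensor_parallel_size
num_layers, num_kv_heads, head_dim = 32, 32, 128
hidden_size, vocab_size = 4096, 32000
n_model_params = 6_740_000_000   # ~7B parameters
nbits_per_param = 16             # fp16 weights
kvcache_block_size = 4096        # assume one block spans the whole max_seq_len

def align(x, nbytes):
    return int(math.ceil(x / nbytes) * nbytes)

def align_2mb(x):
    return align(x, 2 * 1024 * 1024)

# Weight bytes resident in DRAM: lm_head budgeted separately, 2 MB-aligned per shard.
lm_head_params = align(vocab_size, 64) * hidden_size
kernel_bytes = (
    align_2mb(lm_head_params * nbits_per_param // 8 / tp) * tp
    + align_2mb((n_model_params - lm_head_params) * nbits_per_param // 8 / num_layers / tp)
    * num_layers * tp
)

# DRAM budget minus weights and the 1 GB safety buffer (kept whole since tp > 2).
available = int(tp * 15.7 * 2**30) - kernel_bytes - 2**30

# Per-block KV bytes: block * head_dim * sharded KV heads * 2 bytes (fp16),
# per layer, for both K and V, across all shards.
bytes_per_block = (
    align_2mb(kvcache_block_size * head_dim * math.ceil(num_kv_heads / tp) * 2)
    * num_layers * 2 * tp
)

print(available // bytes_per_block)  # upper bound on kvcache_num_blocks (~24 here)
```

The final `kvcache_num_blocks` used for compilation is then the smaller of this bound and `(max_seq_len // block_size) * batch_size`, raised to at least `max_seq_len // block_size + 1` so a single sequence can always be scheduled, and compilation fails outright if it ends up below the batch size.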
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py CHANGED
@@ -114,11 +114,9 @@ class Seq2SeqEncoderWrapper(nn.Module):
 
  # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
  batch_axis = torch.tensor(1, dtype=torch.int16)
- cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
- cross_key_values, cross_kv, b_idx[0], batch_axis
- )
+ enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
 
- return cross_key_values
+ return enc_out
 
 
  class Seq2SeqDecoderWrapper(nn.Module):
@@ -193,7 +191,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
  cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
  # decode
- lm_logits, self_present_key_values = self.conditional_generation(
+ lm_logits = self.conditional_generation(
  input_ids=input_ids,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -203,9 +201,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
  block_tables=block_tables,
  )
 
- outputs = (lm_logits,) + self_present_key_values
-
- return outputs
+ return lm_logits
 
 
  class Seq2SeqForConditionalGeneration(nn.Module):
@@ -250,7 +246,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
  cache_position,
  block_tables: Optional[torch.Tensor] = None,
  ):
- hidden_states, self_present_key_values = self.decoder(
+ hidden_states = self.decoder(
  input_ids=input_ids,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -265,7 +261,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
 
  lm_logits = self.lm_head(hidden_states)
 
- return lm_logits, self_present_key_values
+ return lm_logits
 
 
  class Seq2SeqDecoder(torch.nn.Module):
@@ -326,11 +322,10 @@ class Seq2SeqDecoder(torch.nn.Module):
  hidden_states = self.apply_position_embedding(hidden_states, cache_position)
 
  # iterate decoder_layer
- self_present_key_values = ()
  for decoder_layer, self_past_key_value, cross_past_key_value in zip(
  self.layers, self_past_key_values, cross_past_key_values
  ):
- hidden_states, self_present_key_value = decoder_layer(
+ hidden_states = decoder_layer(
  hidden_states,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -339,12 +334,11 @@ class Seq2SeqDecoder(torch.nn.Module):
  cache_position=cache_position,
  block_tables=block_tables,
  )
- self_present_key_values += self_present_key_value
 
  if self.final_layer_norm is not None:
  hidden_states = self.final_layer_norm(hidden_states)
 
- return hidden_states, self_present_key_values
+ return hidden_states
 
 
  class Seq2SeqDecoderLayer(torch.nn.Module):
@@ -404,7 +398,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
  # Self Attention Block
  residual = hidden_states
  hidden_states = self.pre_self_attn_layer_norm(hidden_states)
- hidden_states, self_attn_past_key_value = self.self_attn(
+ hidden_states = self.self_attn(
  hidden_states=hidden_states,
  past_key_value=self_past_key_value,
  attention_mask=attention_mask,
@@ -429,7 +423,7 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
  # Feed-Forward Block
  hidden_states = self.ff_layer(hidden_states)
 
- return hidden_states, self_attn_past_key_value
+ return hidden_states
 
 
  class Seq2SeqSelfAttention(nn.Module):
@@ -492,12 +486,11 @@ class Seq2SeqSelfAttention(nn.Module):
  if attention_mask is not None:
  args.insert(3, attention_mask.unsqueeze(2))
 
- attn_output, key_states, value_states = self.attn_decode(*args)
+ attn_output = self.attn_decode(*args)
 
  attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
  attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
 
  attn_output = self.out_proj(attn_output)
- present_key_value = (key_states, value_states)
 
- return attn_output, present_key_value
+ return attn_output
optimum/rbln/transformers/models/t5/t5_architecture.py CHANGED
@@ -88,7 +88,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
  cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
  # decode
- lm_logits, self_present_key_values = self.conditional_generation(
+ lm_logits = self.conditional_generation(
  input_ids=input_ids,
  attention_mask=attention_mask,
  encoder_attention_mask=encoder_attention_mask,
@@ -97,9 +97,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
  cache_position=cache_position,
  )
 
- outputs = (lm_logits,) + self_present_key_values
-
- return outputs
+ return lm_logits
 
 
  class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
@@ -187,7 +185,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
  key_states = self._shape(key_states, -1, bsz)
  value_states = self._shape(value_states, -1, bsz)
 
- attn_output, key_states, value_states = self.attn_decode(
+ attn_output = self.attn_decode(
  query_states,
  key_states,
  value_states,
@@ -204,9 +202,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
  attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
 
  attn_output = self.out_proj(attn_output)
- present_key_value = (key_states, value_states)
-
- return attn_output, present_key_value
+ return attn_output
 
 
  class T5CrossAttention(nn.Module):
optimum/rbln/transformers/models/whisper/whisper_architecture.py CHANGED
@@ -78,9 +78,9 @@ class WhisperEncoderWrapper(torch.nn.Module):
  # 3. update cross_attention's past_key_value to the device-dram for optimization.
  bidx = torch.tensor(0, dtype=torch.int16)
  axis = torch.tensor(1, dtype=torch.int16)
- cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+ enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
 
- return cross_key_values
+ return enc_output
 
 
  class WhisperDecoderWrapper(torch.nn.Module):
@@ -119,7 +119,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
  cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
  # Decode
- sequence_output, self_present_key_values, cross_attentions = self.decoder(
+ sequence_output, cross_attentions = self.decoder(
  input_ids=decoder_input_ids,
  attention_mask=decoder_attention_mask,
  cache_position=cache_position,
@@ -128,9 +128,7 @@ class WhisperDecoderWrapper(torch.nn.Module):
  )
 
  lm_logits = self.proj_out(sequence_output)
-
  outputs = (lm_logits,)
- outputs += self_present_key_values
 
  if self.output_attentions:
  # deocder's cross attention is used for token_timestamps
@@ -168,26 +166,23 @@ class WhisperDecoder(nn.Module):
  # prepare casual_attn_mask
  attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
 
- self_present_key_values = ()
  cross_attentions = ()
  # iterate decoder_layer
  for self_past_key_value, cross_past_key_value, decoder_layer in zip(
  self_past_key_values, cross_past_key_values, self.layers
  ):
- layer_outputs = decoder_layer(
+ hidden_states, cross_attn_weights = decoder_layer(
  hidden_states,
  attention_mask=attention_mask,
  self_past_key_value=self_past_key_value,
  cross_past_key_value=cross_past_key_value,
  cache_position=cache_position,
  )
- hidden_states = layer_outputs[0]
- self_present_key_values += layer_outputs[1]
- cross_attentions += (layer_outputs[2],)
+ cross_attentions += (cross_attn_weights,)
 
  hidden_states = self.layer_norm(hidden_states)
 
- return hidden_states, self_present_key_values, cross_attentions
+ return hidden_states, cross_attentions
 
 
  class WhisperDecoderLayer(nn.Module):
@@ -214,7 +209,7 @@ class WhisperDecoderLayer(nn.Module):
  # Self Attention Block
  residual = hidden_states
  hidden_states = self.self_attn_layer_norm(hidden_states)
- hidden_states, self_present_key_value = self.self_attn(
+ hidden_states = self.self_attn(
  hidden_states=hidden_states,
  past_key_value=self_past_key_value,
  attention_mask=attention_mask,
@@ -238,7 +233,7 @@ class WhisperDecoderLayer(nn.Module):
  hidden_states = self.fc2(hidden_states)
  hidden_states = residual + hidden_states
 
- return hidden_states, self_present_key_value, cross_attn_weights
+ return hidden_states, cross_attn_weights
 
 
  class WhisperAttention(nn.Module):
@@ -276,7 +271,7 @@ class WhisperSelfAttention(WhisperAttention):
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
 
- attn_output, key_states, value_states = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
+ attn_output = torch.ops.rbln_custom_ops.add_softmax_attn_decode(
  query_states,
  key_states,
  value_states,
@@ -292,7 +287,7 @@ class WhisperSelfAttention(WhisperAttention):
  attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  attn_output = self.out_proj(attn_output)
 
- return attn_output, (key_states, value_states)
+ return attn_output
 
 
  class WhisperCrossAttention(nn.Module):
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.7.3a3
+ Version: 0.7.3a4
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai
@@ -1,7 +1,7 @@
  optimum/rbln/__init__.py,sha256=eHi15YM3989AcX52jka9rUmgAtlp1PHqMNwBEdOfuu8,6554
- optimum/rbln/__version__.py,sha256=jlkAV1bws10Tgk9b3JF90gq1GOekHphDutCCDtjNFJc,519
- optimum/rbln/modeling.py,sha256=3XE0IrCYbkjw9_Q9BFzZ_ri_Kyxw1g6iwfdohZB46-s,8289
- optimum/rbln/modeling_base.py,sha256=ELSPbjx7awBRM2SckkD-5gI3TIa01mfzz7gDRC1Pljk,21778
+ optimum/rbln/__version__.py,sha256=MLlg_138GxyhciEP0ZB5dPN8vriXkicRnaZiwqygxOY,519
+ optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
+ optimum/rbln/modeling_base.py,sha256=Ow73GVJF1N5cDFO8_rgirtGj1wC-cXBDyqXHW5PCybA,22270
  optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
  optimum/rbln/diffusers/__init__.py,sha256=pOyoXv3-JRzTBSwPKbgLS9H6F2K9dJdReEmpGhcLQYU,3283
  optimum/rbln/diffusers/modeling_diffusers.py,sha256=zqVNgH9oeOx2iNE7VsW_FinVf4s6G5Idyh4TKz7XJJg,21116
@@ -40,9 +40,9 @@ optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_x
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=3aB1Rw-OgKytQOHwOaShbEvq_XVHPOGvsGm8pstEmKU,930
  optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=MzVP1wscaO1sUIiBIPJqG6zuGyez9VUbA42-JSIm-mk,930
  optimum/rbln/ops/__init__.py,sha256=TxOmsN0u3PmyK4Sb89qbiC4rePOlkvUT7Lm6wVoTnY0,941
- optimum/rbln/ops/attn.py,sha256=LbJAmFtNj05i6BURfKV3KybsPItFe8w-YdSe5SuWkEc,12365
- optimum/rbln/ops/flash_attn.py,sha256=4shKNY13skPoYnbEsGrXDzgNwBIhHZEFrnUnWx1ESZU,4076
- optimum/rbln/ops/kv_cache_update.py,sha256=9W4WCO1Dtfy0u5i978JJRa7uLbqrfR2lHuoPynb07fw,3143
+ optimum/rbln/ops/attn.py,sha256=3EqU63Oj4zI4rLbkRycorsscXeD-IpKzt9N1MhkMa5o,10374
+ optimum/rbln/ops/flash_attn.py,sha256=wfyiCxDGf034IngzwRU160R7_DlKYpd-uWT0BDEGFks,3408
+ optimum/rbln/ops/kv_cache_update.py,sha256=pxf8kAptPaQF5xE8qItvmlFOq_sgim6ZERD7AVaOtec,3221
  optimum/rbln/transformers/__init__.py,sha256=AGo3BqVIZrsOzYsQAnnQ25HCstTPBclrXbvvUxVMlqE,4255
  optimum/rbln/transformers/modeling_alias.py,sha256=yx7FnZQWAnrWzivaO5hI7T6i-fyLzt2tMIXG2oDNbPo,1657
  optimum/rbln/transformers/modeling_generic.py,sha256=aaZWsqVDCRvH03q-Wen7DMfLr7Gy-u-I0mTw0aYqWjk,18195
@@ -59,8 +59,8 @@ optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaM
  optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
  optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
  optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=x8_xQ5aGXbadJyajpJQyfgxx4YPSj62VlmmGDMnC-1E,41819
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=DylKxV1kFbDv34txpuI5JrvMcSTa2W910eO9dmF0o_8,35352
+ optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=7OIKteJLKNxOLOg0w3lLOM7TxZovQn4jkglI9wRkrtQ,40609
+ optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=W9HnxJoTz78Wc4X5Q3sMSHhMTSa7-9uQCFlnqNVozvA,38932
  optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
  optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
  optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -91,16 +91,16 @@ optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz
  optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
  optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
  optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=QXIGWSu9PsKWE3WhkgmBj3VeszqXIo2MPOwcrb54Tbs,19348
+ optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
  optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
  optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
- optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=oCdmF4eCTayAVjx3c-SVpmhrjnWE92jh79dMIYCwotY,9690
+ optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=AArCQhZRETVM583wlIRzMFOSYq7t2nzxaAeyhZxyxKk,9508
  optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
  optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
  optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
  optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
  optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
- optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=_6PmE4-DD5QhohQwHW5M11q_L9f_ayF6StmNTlOYJdg,12896
+ optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=Yn6yFpmw6IQbWlnpIMAdEUsNF6huXgaKzGMUZbhSLdo,12572
  optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
  optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=8YNLz0bc5ze-QuU8rN-QhUfGzlSUs3iMJiWTxO3o6AM,4366
  optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -114,7 +114,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
  optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
  optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
- optimum_rbln-0.7.3a3.dist-info/METADATA,sha256=UQs6c3GdXbPYE8wSnT6Ca9TtgfKwEgPNVZk-MoAKQPc,5300
- optimum_rbln-0.7.3a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- optimum_rbln-0.7.3a3.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
- optimum_rbln-0.7.3a3.dist-info/RECORD,,
+ optimum_rbln-0.7.3a4.dist-info/METADATA,sha256=8VNTOVgsgFtcFUuZ9VEeRQfC2LEB60OFmW92hlJo8V8,5300
+ optimum_rbln-0.7.3a4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ optimum_rbln-0.7.3a4.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ optimum_rbln-0.7.3a4.dist-info/RECORD,,