optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4a1__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
optimum/rbln/__init__.py CHANGED
@@ -73,6 +73,7 @@ _import_structure = {
          "RBLNRobertaForMaskedLM",
          "RBLNViTForImageClassification",
          "RBLNBertForMaskedLM",
+         "RBLNTimeSeriesTransformerForPrediction",
      ],
      "diffusers": [
          "RBLNAutoencoderKL",
@@ -184,6 +185,7 @@ if TYPE_CHECKING:
          RBLNRobertaForSequenceClassification,
          RBLNT5EncoderModel,
          RBLNT5ForConditionalGeneration,
+         RBLNTimeSeriesTransformerForPrediction,
          RBLNViTForImageClassification,
          RBLNWav2Vec2ForCTC,
          RBLNWhisperForConditionalGeneration,
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.7.3.post2'
- __version_tuple__ = version_tuple = (0, 7, 3)
+ __version__ = version = '0.7.4a1'
+ __version_tuple__ = version_tuple = (0, 7, 4)
@@ -13,9 +13,10 @@
  # limitations under the License.

  from .attn import (
-     register_rbln_custom_add_softmax_attention,
+     register_rbln_custom_paged_add_softmax_attention,
      register_rbln_custom_paged_attention,
      register_rbln_custom_paged_causal_attention,
  )
  from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
  from .kv_cache_update import register_rbln_custom_cache_update
+ from .linear import linear
optimum/rbln/ops/attn.py CHANGED
@@ -182,14 +182,14 @@ def register_rbln_custom_paged_causal_attention():


  @lru_cache
- def register_rbln_custom_add_softmax_attention():
+ def register_rbln_custom_paged_add_softmax_attention():
      torch.library.define(
-         "rbln_custom_ops::add_softmax_attn_decode",
-         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor",
+         "rbln_custom_ops::paged_add_softmax_attn_decode",
+         "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor",
      )

-     @torch.library.impl("rbln_custom_ops::add_softmax_attn_decode", "cpu")
-     def add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+     @torch.library.impl("rbln_custom_ops::paged_add_softmax_attn_decode", "cpu")
+     def paged_add_softmax_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size):
          """Defines the computation pattern for fused attention with KV cache updates.

          IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
@@ -210,12 +210,14 @@ def register_rbln_custom_add_softmax_attention():
          - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
          - seq: [1] - Current sequence position
          - scale: [] - Attention scale factor
+         - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
+         - block_size: [] - Number of tokens per block

          Returns:
              Tensor: attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
          """
          return q

-     @register_fake("rbln_custom_ops::add_softmax_attn_decode")
-     def add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+     @register_fake("rbln_custom_ops::paged_add_softmax_attn_decode")
+     def paged_add_softmax_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition, block_table, block_size):
          return q
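Note: a minimal sketch (not part of the diff) of registering and invoking the renamed paged op. Tensor shapes here are illustrative; on CPU the op simply returns q, since it only serves as a pattern for the RBLN compiler to match and replace with the fused kernel.

import torch
from optimum.rbln.ops import register_rbln_custom_paged_add_softmax_attention

register_rbln_custom_paged_add_softmax_attention()  # lru_cached, safe to call repeatedly

batch, n_heads, head_dim, max_seq_len, block_size = 1, 8, 64, 128, 128
q = torch.randn(batch, n_heads, 1, 1, head_dim)
k = torch.randn(batch, n_heads, 1, 1, head_dim)
v = torch.randn(batch, n_heads, 1, 1, head_dim)
mask = torch.zeros(batch, 1, 1, 1, max_seq_len)
kcache = torch.zeros(batch, n_heads, 1, max_seq_len, head_dim)
vcache = torch.zeros(batch, n_heads, 1, max_seq_len, head_dim)
seq = torch.tensor([0], dtype=torch.int32)
scale = torch.tensor(1.0)
block_table = torch.zeros(batch, max_seq_len // block_size, dtype=torch.int16)

out = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
    q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size
)
print(out.shape)  # same as q on CPU; the fused paged-attention kernel is generated at compile time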
@@ -0,0 +1,25 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ from torch import Tensor
+
+
+ @torch.library.custom_op("rbln_custom_ops::linear", mutates_args=())
+ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+     output_shape = list(input.shape[:-1])
+     output_shape += [weight.shape[0]]
+     return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
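Note: this new op is a shape-only stand-in for torch.nn.functional.linear, letting the RBLN compiler recognize and lower linear layers. A small usage sketch (not from the diff; the shapes are illustrative):

import torch
from optimum.rbln.ops import linear  # importing registers rbln_custom_ops::linear

x = torch.randn(2, 16, 512)   # (batch, seq, in_features)
w = torch.randn(1024, 512)    # (out_features, in_features)

out = torch.ops.rbln_custom_ops.linear(x, w)
print(out.shape)  # torch.Size([2, 16, 1024]) - same shape as F.linear(x, w),
                  # but the values are uninitialized (torch.empty)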
@@ -52,6 +52,7 @@ _import_structure = {
      "RBLNPhiForCausalLM",
      "RBLNT5EncoderModel",
      "RBLNT5ForConditionalGeneration",
+     "RBLNTimeSeriesTransformerForPrediction",
      "RBLNLlavaNextForConditionalGeneration",
      "RBLNMidmLMHeadModel",
      "RBLNXLMRobertaModel",
@@ -113,6 +114,7 @@ if TYPE_CHECKING:
      RBLNQwen2ForCausalLM,
      RBLNT5EncoderModel,
      RBLNT5ForConditionalGeneration,
+     RBLNTimeSeriesTransformerForPrediction,
      RBLNWav2Vec2ForCTC,
      RBLNWhisperForConditionalGeneration,
      RBLNXLMRobertaModel,
@@ -50,6 +50,7 @@ _import_structure = {
      "mistral": ["RBLNMistralForCausalLM"],
      "phi": ["RBLNPhiForCausalLM"],
      "qwen2": ["RBLNQwen2ForCausalLM"],
+     "time_series_transformers": ["RBLNTimeSeriesTransformerForPrediction"],
      "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
      "wav2vec2": ["RBLNWav2Vec2ForCTC"],
      "whisper": ["RBLNWhisperForConditionalGeneration"],
@@ -90,6 +91,7 @@ if TYPE_CHECKING:
      from .phi import RBLNPhiForCausalLM
      from .qwen2 import RBLNQwen2ForCausalLM
      from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+     from .time_series_transformers import RBLNTimeSeriesTransformerForPrediction
      from .wav2vec2 import RBLNWav2Vec2ForCTC
      from .whisper import RBLNWhisperForConditionalGeneration
      from .xlm_roberta import RBLNXLMRobertaModel
@@ -94,12 +94,11 @@ class RBLNBartModel(RBLNModel):
              for model_input_name in rbln_model_input_names
          ]

-         enc_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="encoder")
-         dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)

          rbln_config = RBLNConfig(
              rbln_cls=cls.__name__,
-             compile_cfgs=[enc_compile_config, dec_compile_config],
+             compile_cfgs=[rbln_compile_config],
              rbln_kwargs=rbln_kwargs,
          )

@@ -108,6 +107,8 @@ class RBLNBartModel(RBLNModel):


  class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     support_causal_attn = True
+
      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
          enc_max_seq_len = (
@@ -222,8 +222,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

          attention_mask = self.dec_attn_mask

-         attention_mask = self.dec_attn_mask
-
          logits = super().forward(
              inputs,
              cache_position,
@@ -547,22 +545,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

          @QuantizationManager.with_quantization_env
          def compile_model(*args, **kwargs):
-             wrapped_model.phase = "prefill"
-             compiled_prefill = RBLNModel.compile(
-                 wrapped_model,
-                 prefill_compile_config,
-                 example_inputs=prefill_example_inputs,
-                 compile_context=context,
-             )
+             try:
+                 original_linear = torch.nn.functional.linear
+                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+                 wrapped_model.phase = "prefill"
+                 compiled_prefill = RBLNModel.compile(
+                     wrapped_model,
+                     prefill_compile_config,
+                     example_inputs=prefill_example_inputs,
+                     compile_context=context,
+                 )

-             wrapped_model.phase = "decode"
-             compiled_decoder = RBLNModel.compile(
-                 wrapped_model,
-                 dec_compile_config,
-                 example_inputs=dec_example_inputs,
-                 compile_context=context,
-             )
-             return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+                 wrapped_model.phase = "decode"
+                 compiled_decoder = RBLNModel.compile(
+                     wrapped_model,
+                     dec_compile_config,
+                     example_inputs=dec_example_inputs,
+                     compile_context=context,
+                 )
+                 return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+             finally:
+                 torch.nn.functional.linear = original_linear

          return compile_model(quantize_config=quantize_config)

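Note: the hunk above temporarily routes torch.nn.functional.linear through the rbln_custom_ops::linear stub while the prefill and decode graphs are compiled, and restores the original in a finally block. A standalone sketch of the same swap-and-restore pattern (illustrative helper and model, not from the codebase):

import torch
import torch.nn as nn
from optimum.rbln.ops import linear  # noqa: F401  (importing registers rbln_custom_ops::linear)

def run_with_rbln_linear(model: nn.Module, example_input: torch.Tensor) -> torch.Tensor:
    original_linear = torch.nn.functional.linear
    try:
        # While patched, every nn.Linear executed below dispatches to the
        # shape-only stub op that the RBLN compiler pattern-matches.
        torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
        return model(example_input)
    finally:
        # Always restore the original implementation, even if compilation fails.
        torch.nn.functional.linear = original_linear

out = run_with_rbln_linear(nn.Linear(512, 1024), torch.randn(2, 512))
print(out.shape)  # torch.Size([2, 1024]); values are uninitialized on CPU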
@@ -38,8 +38,8 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]

      def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-         _ = super().forward(*args, **kwargs)
-         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+         output = super().forward(*args, **kwargs)
+         return BaseModelOutput(last_hidden_state=output)


  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
@@ -94,7 +94,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
              decoder_attention_mask if self.use_attention_mask else None,
              attention_mask,
              cache_position,
-             block_tables,
+             block_tables=block_tables,
          )

          return Seq2SeqLMOutput(logits=lm_logits)
@@ -115,6 +115,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

      main_input_name = "input_ids"
      auto_model_class = AutoModelForSeq2SeqLM
+     support_causal_attn = None

      def __post_init__(self, **kwargs):
          batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -186,13 +187,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
          rbln_batch_size = rbln_kwargs.get("batch_size", None)
          rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-         rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)

-         if rbln_use_attention_mask is None:
-             rbln_use_attention_mask = False
-             rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-             if rbln_npu == "RBLN-CA02":
-                 rbln_use_attention_mask = True
+         if cls.support_causal_attn:
+             rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+             if rbln_use_attention_mask is None:
+                 rbln_use_attention_mask = False
+                 rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+                 if rbln_npu == "RBLN-CA02":
+                     rbln_use_attention_mask = True
+         else:
+             rbln_use_attention_mask = True

          n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
          n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -265,11 +269,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                  [rbln_batch_size, 1],
                  "int32",
              ),
-             (
-                 "block_tables",
-                 [rbln_batch_size, 1],
-                 "int16",
-             ),
+             ("block_tables", [rbln_batch_size, 1], "int16"),
          ]
          dec_input_info.extend(
              [
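Note: the new support_causal_attn class attribute lets each seq2seq subclass declare whether causal attention (and therefore maskless decoding on supported NPUs) is available; _get_rbln_config now falls back to always using an attention mask when it is not. A condensed, standalone restatement of that dispatch (illustrative only; the real logic lives in RBLNModelForSeq2SeqLM._get_rbln_config):

def resolve_use_attention_mask(model_cls, rbln_kwargs: dict, npu_name: str) -> bool:
    if model_cls.support_causal_attn:
        # Causal-attention subclasses (e.g. RBLNBartForConditionalGeneration,
        # which sets support_causal_attn = True) may drop the mask, except on RBLN-CA02.
        use_attention_mask = rbln_kwargs.get("use_attention_mask")
        if use_attention_mask is None:
            use_attention_mask = npu_name == "RBLN-CA02"
        return use_attention_mask
    # Subclasses without causal attention (e.g. RBLNT5ForConditionalGeneration,
    # which sets support_causal_attn = False) always keep the attention mask.
    return True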
@@ -13,9 +13,8 @@
  # limitations under the License.

  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

- import rebel
  import torch
  from transformers import (
      AutoModelForTextEncoding,
@@ -23,7 +22,7 @@ from transformers import (
      T5EncoderModel,
      T5ForConditionalGeneration,
  )
- from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+ from transformers.modeling_outputs import BaseModelOutput

  from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
  from ....modeling import RBLNModel
@@ -58,63 +57,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          )


- class RBLNRuntimeEncoder(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name"]
-
-     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-         _ = super().forward(*args, **kwargs)
-         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
-
-
- class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name"]
-
-     def __init__(
-         self,
-         runtime: rebel.Runtime,
-         batch_size: int,
-         dec_max_seq_len: int,
-         **kwargs: Any,
-     ) -> None:
-         super().__init__(runtime, **kwargs)
-         self.batch_size = batch_size
-         self.dec_max_seq_len = dec_max_seq_len
-
-     def forward(
-         self,
-         decoder_input_ids: Optional[torch.LongTensor] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         decoder_attention_mask: Optional[torch.BoolTensor] = None,
-         cache_position: Optional[torch.Tensor] = None,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor]:
-         batch_size = decoder_input_ids.shape[0]
-         if batch_size != self.batch_size:
-             raise RuntimeError(
-                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-             )
-
-         if batch_size != cache_position.shape[0]:
-             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-         for b_idx in range(self.batch_size):
-             decoding_step = cache_position[b_idx].item()
-             if not (0 <= decoding_step < self.dec_max_seq_len):
-                 raise ValueError(
-                     f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                 )
-             decoder_attention_mask[b_idx, : decoding_step + 1] = 1
-
-         lm_logits = super().forward(
-             decoder_input_ids,
-             decoder_attention_mask,
-             attention_mask,
-             cache_position,
-         )
-
-         return Seq2SeqLMOutput(logits=lm_logits)
-
-
  class T5EncoderWrapper(torch.nn.Module):
      def __init__(self, model: "T5EncoderModel") -> None:
          super().__init__()
@@ -247,20 +189,7 @@ class RBLNT5EncoderModel(RBLNModel):


  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
-     def __post_init__(self, **kwargs):
-         batch_size = self.rbln_config.model_cfg["batch_size"]
-         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-
-         self.encoder = RBLNRuntimeEncoder(
-             runtime=self.model[0],
-             main_input_name="input_ids",
-         )
-         self.decoder = RBLNRuntimeDecoder(
-             runtime=self.model[1],
-             main_input_name="input_ids",
-             batch_size=batch_size,
-             dec_max_seq_len=dec_max_seq_len,
-         )
+     support_causal_attn = False

      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
@@ -279,139 +208,3 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
              return redirect(val)

          return val
-
-     @classmethod
-     def _get_rbln_config(
-         cls,
-         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "PretrainedConfig",
-         rbln_kwargs: Dict[str, Any] = {},
-     ) -> RBLNConfig:
-         rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
-         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
-         d_kv = (
-             model_config.d_kv
-             if hasattr(model_config, "d_kv")
-             else model_config.d_model // model_config.encoder_attention_heads
-         )
-
-         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-             model_config, "max_position_embeddings", None
-         )
-
-         rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-         if rbln_pad_token_id is None:
-             rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-             if rbln_pad_token_id is None:
-                 rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                 if rbln_pad_token_id is None:
-                     rbln_pad_token_id = -1
-
-         if rbln_enc_max_seq_len is None:
-             rbln_enc_max_seq_len = max_position_embeddings
-             if rbln_enc_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_enc_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_enc_max_seq_len is None:
-                     raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-         if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-         if rbln_dec_max_seq_len is None:
-             rbln_dec_max_seq_len = max_position_embeddings
-             if rbln_dec_max_seq_len is None:
-                 for tokenizer in preprocessors:
-                     if hasattr(tokenizer, "model_max_length"):
-                         rbln_dec_max_seq_len = tokenizer.model_max_length
-                         break
-                 if rbln_dec_max_seq_len is None:
-                     raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
-
-         # model input info
-         enc_input_info = [
-             ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-             ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-             (
-                 "cross_key_value_states",
-                 [
-                     n_layer * 2,
-                     rbln_batch_size,
-                     n_head,
-                     rbln_enc_max_seq_len,
-                     d_kv,
-                 ],
-                 "float32",
-             ),
-             ("block_tables", [1], "int16"),
-         ]
-
-         dec_input_info = [
-             ("input_ids", [rbln_batch_size, 1], "int64"),
-             ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
-             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
-             (
-                 "cache_position",
-                 [rbln_batch_size, 1],
-                 "int32",
-             ),
-         ]
-         dec_input_info.extend(
-             [
-                 (
-                     "cross_key_value_states",
-                     [
-                         n_layer * 2,
-                         rbln_batch_size,
-                         n_head,
-                         rbln_enc_max_seq_len,
-                         d_kv,
-                     ],
-                     "float32",
-                 )
-             ]
-         )
-         dec_input_info.extend(
-             [
-                 (
-                     f"self_key_value_states_{i}",
-                     [
-                         rbln_batch_size,
-                         n_head,
-                         rbln_dec_max_seq_len,
-                         d_kv,
-                     ],
-                     "float32",
-                 )
-                 for i in range(n_layer * 2)
-             ]
-         )
-
-         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
-         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
-
-         rbln_config = RBLNConfig(
-             rbln_cls=cls.__name__,
-             compile_cfgs=[enc_compile_config, dec_compile_config],
-             rbln_kwargs=rbln_kwargs,
-         )
-
-         rbln_config.model_cfg.update(
-             {
-                 "enc_max_seq_len": rbln_enc_max_seq_len,
-                 "dec_max_seq_len": rbln_dec_max_seq_len,
-                 "batch_size": rbln_batch_size,
-                 "pad_token_id": rbln_pad_token_id,
-             }
-         )
-
-         return rbln_config
@@ -18,7 +18,7 @@ import torch
  from torch import nn
  from transformers.utils import logging

- from ....ops import register_rbln_custom_add_softmax_attention
+ from ....ops import register_rbln_custom_paged_add_softmax_attention
  from ..seq2seq.seq2seq_architecture import (
      Seq2SeqDecoder,
      Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):

  class T5DecoderWrapper(Seq2SeqDecoderWrapper):
      def __post_init__(self, model, dec_max_seq_len: int = None):
-         register_rbln_custom_add_softmax_attention()
+         register_rbln_custom_paged_add_softmax_attention()
          self.num_layers = self.config.num_layers
          self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

@@ -77,6 +77,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
          attention_mask,
          encoder_attention_mask,
          cache_position,
+         block_tables,
          cross_kv_cache,
          *self_kv_cache,
      ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
@@ -95,6 +96,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
              self_past_key_values=self_past_key_values,
              cross_past_key_values=cross_past_key_values,
              cache_position=cache_position,
+             block_tables=block_tables,
          )

          return lm_logits
@@ -162,7 +164,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          self.out_proj = self._original_mod.o
          self.num_heads = self._original_mod.n_heads
          self.head_dim = self._original_mod.key_value_proj_dim
-         self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
+         self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode

      def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          query_states = self.q_proj(hidden_states)
@@ -176,6 +178,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          past_key_value: Tuple[torch.Tensor],
          attention_mask: torch.Tensor,
          cache_position: torch.Tensor,
+         block_tables: torch.Tensor,
          **kwargs,
      ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
          bsz, tgt_len, _ = hidden_states.size()
@@ -185,6 +188,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
          key_states = self._shape(key_states, -1, bsz)
          value_states = self._shape(value_states, -1, bsz)

+         block_size = past_key_value[0].shape[-2]
          attn_output = self.attn_decode(
              query_states,
              key_states,
@@ -196,6 +200,8 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
              past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
              cache_position,
              torch.tensor(1.0, dtype=torch.float32), # scale
+             block_tables,
+             block_size,
          )

          attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
@@ -0,0 +1,24 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
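Note: the headline addition of this release is the exported RBLNTimeSeriesTransformerForPrediction class. A hypothetical usage sketch, assuming it follows the usual optimum-rbln from_pretrained/export flow; the checkpoint id and rbln_* kwargs are illustrative and not taken from this diff:

from optimum.rbln import RBLNTimeSeriesTransformerForPrediction

# Compile a Hugging Face time-series transformer checkpoint for RBLN NPUs
# (hypothetical arguments; consult the release documentation for the exact
# supported options).
model = RBLNTimeSeriesTransformerForPrediction.from_pretrained(
    "huggingface/time-series-transformer-tourism-monthly",
    export=True,
    rbln_batch_size=1,
)
model.save_pretrained("time_series_transformer_rbln")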