optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py

@@ -18,12 +18,6 @@ import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import (
-    register_rbln_custom_cache_update,
-    register_rbln_custom_paged_attention,
-    register_rbln_custom_paged_causal_attention,
-)
-
 
 logger = logging.get_logger(__name__)
 
@@ -59,7 +53,6 @@ class Seq2SeqEncoderWrapper(nn.Module):
 
     def __init__(self, model: nn.Module, enc_max_seq_len: int):
         super().__init__()
-        register_rbln_custom_cache_update()
         self.config = model.config
         self.encoder = model.get_encoder()
         self.encoder_max_length = enc_max_seq_len
@@ -90,8 +83,8 @@ class Seq2SeqEncoderWrapper(nn.Module):
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
-        cross_key_values: torch.Tensor,
         b_idx: torch.Tensor,
+        *cross_key_values: Tuple[torch.Tensor],
     ) -> Tuple[torch.Tensor]:
         # 1. get encoder last_hidden_states
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -110,13 +103,15 @@ class Seq2SeqEncoderWrapper(nn.Module):
             cross_kv.append(past_k)
             cross_kv.append(past_v)
 
-        cross_kv = torch.stack(cross_kv, dim=0)
-
         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
-        batch_axis = torch.tensor(1, dtype=torch.int16)
-        enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
+        batch_axis = torch.tensor(0, dtype=torch.int16)
+        cross_key_values = list(cross_key_values)
+        for i in range(self.n_layer * 2):
+            cross_key_values[i] = torch.ops.rbln_custom_ops.rbln_cache_update(
+                cross_key_values[i], cross_kv[i], b_idx[0], batch_axis
+            )
 
-        return enc_out
+        return cross_key_values
 
 
 class Seq2SeqDecoderWrapper(nn.Module):
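Note on the hunk above: 0.7.4 stops stacking every layer's cross-attention K/V into one tensor for a single cache update along batch axis 1; each of the 2 * n_layer cache tensors is now a separate graph input, updated in place along batch axis 0, and the updated tensors themselves are returned instead of one stacked enc_out. A minimal plain-PyTorch sketch of the per-layer pattern, where index_copy stands in for the rbln_cache_update custom op and all shapes are illustrative:

    import torch

    n_layer, batch, n_head, enc_len, d_kv = 2, 4, 8, 64, 32

    # One cache tensor per layer per K/V, shaped [batch, n_head, enc_len, d_kv].
    caches = [torch.zeros(batch, n_head, enc_len, d_kv) for _ in range(n_layer * 2)]
    # Freshly computed cross K/V for a single request (batch slice of size 1).
    fresh_kv = [torch.randn(1, n_head, enc_len, d_kv) for _ in range(n_layer * 2)]

    b_idx = torch.tensor([2])  # batch slot this request occupies
    for i in range(n_layer * 2):
        # batch_axis = 0: copy this request's K/V into its slot of cache i
        caches[i] = caches[i].index_copy(0, b_idx, fresh_kv[i])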
@@ -146,11 +141,6 @@ class Seq2SeqDecoderWrapper(nn.Module):
         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
         by subclasses to modify or add custom attributes as necessary.
         """
-        if self.use_attention_mask:
-            register_rbln_custom_paged_attention()
-        else:
-            register_rbln_custom_paged_causal_attention()
-
         self.num_layers = self.config.decoder_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
@@ -176,16 +166,17 @@ class Seq2SeqDecoderWrapper(nn.Module):
                 encoder_attention_mask,
                 cache_position,
                 block_tables,
-                cross_kv_cache,
-                *self_kv_cache,
+                *kv_cache,
             ) = args
 
         else:
             attention_mask = None
-            (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+            (input_ids, encoder_attention_mask, cache_position, block_tables, *kv_cache) = args
 
         self_past_key_values = ()
         cross_past_key_values = ()
+        self_kv_cache = kv_cache[self.num_layers * 2 :]
+        cross_kv_cache = kv_cache[: self.num_layers * 2]
         for i in range(0, self.num_layers * 2, 2):
             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
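The decoder side adopts a matching flattened convention: the separate cross_kv_cache tensor and *self_kv_cache collapse into one *kv_cache, with the cross-attention caches in the first num_layers * 2 slots and K/V interleaved per layer. A small self-contained sketch of the unpacking, under the layout the diff shows:

    from typing import List, Tuple
    import torch

    def split_kv_cache(kv_cache: List[torch.Tensor], num_layers: int) -> Tuple[tuple, tuple]:
        # Layout from the diff: cross caches first, then self caches,
        # each interleaved as (key_0, value_0, key_1, value_1, ...).
        cross_flat = kv_cache[: num_layers * 2]
        self_flat = kv_cache[num_layers * 2 :]
        self_past = tuple((self_flat[i], self_flat[i + 1]) for i in range(0, num_layers * 2, 2))
        cross_past = tuple((cross_flat[i], cross_flat[i + 1]) for i in range(0, num_layers * 2, 2))
        return self_past, cross_past

    dummy = [torch.zeros(1) for _ in range(8)]  # 2 layers -> 4 cross + 4 self tensors
    self_past, cross_past = split_kv_cache(dummy, num_layers=2)
    assert len(self_past) == len(cross_past) == 2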
optimum/rbln/transformers/models/t5/__init__.py

@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ....ops import paged_add_softmax_attn_decode
+from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
 from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
optimum/rbln/transformers/models/t5/configuration_t5.py (new file)

@@ -0,0 +1,24 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
+
+
+class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+    pass
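The two empty subclasses exist to slot T5 into the typed configuration hierarchy that 0.7.4 introduces (the new configuration_utils.py replaces the removed modeling_config.py): options that were previously loose rbln_* kwargs become attributes of an RBLNModelConfig object. A hedged sketch of what export-time usage likely looks like; the checkpoint ID and field values are illustrative, not taken from this diff:

    from optimum.rbln import (
        RBLNT5ForConditionalGeneration,
        RBLNT5ForConditionalGenerationConfig,
    )

    # enc_max_seq_len / dec_max_seq_len are the fields the new
    # wrap_model_if_needed reads off the config (see modeling_t5.py below).
    rbln_config = RBLNT5ForConditionalGenerationConfig(
        batch_size=1,
        enc_max_seq_len=512,
        dec_max_seq_len=512,
    )
    model = RBLNT5ForConditionalGeneration.from_pretrained(
        "google-t5/t5-small",
        export=True,
        rbln_config=rbln_config,
    )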
optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -13,106 +13,21 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable
 
-import rebel
 import torch
-from transformers import (
-    AutoModelForTextEncoding,
-    PretrainedConfig,
-    T5EncoderModel,
-    T5ForConditionalGeneration,
-)
-from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
 
-from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....utils.logging import get_logger
-from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
 from .t5_architecture import T5Wrapper
 
 
-logger = get_logger()
-
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
-
-
-class RBLNRuntimeModel(RBLNPytorchRuntime):
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.FloatTensor,
-        head_mask: torch.FloatTensor,
-        inputs_embeds: torch.FloatTensor,
-        **kwargs,
-    ):
-        return super().forward(
-            input_ids,
-            attention_mask,
-            head_mask,
-            inputs_embeds,
-            **kwargs,
-        )
-
-
-class RBLNRuntimeEncoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        _ = super().forward(*args, **kwargs)
-        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
-
+    from transformers import PreTrainedModel
 
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    def __init__(
-        self,
-        runtime: rebel.Runtime,
-        batch_size: int,
-        dec_max_seq_len: int,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(runtime, **kwargs)
-        self.batch_size = batch_size
-        self.dec_max_seq_len = dec_max_seq_len
-
-    def forward(
-        self,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        decoder_attention_mask: Optional[torch.BoolTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        batch_size = decoder_input_ids.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        for b_idx in range(self.batch_size):
-            decoding_step = cache_position[b_idx].item()
-            if not (0 <= decoding_step < self.dec_max_seq_len):
-                raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
-
-        lm_logits = super().forward(
-            decoder_input_ids,
-            decoder_attention_mask,
-            attention_mask,
-            cache_position,
-        )
-
-        return Seq2SeqLMOutput(logits=lm_logits)
+    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 class T5EncoderWrapper(torch.nn.Module):
@@ -125,149 +40,35 @@ class T5EncoderWrapper(torch.nn.Module):
         return self.model(*args, **kwargs, return_dict=False)
 
 
-class RBLNT5EncoderModel(RBLNModel):
+class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
     auto_model_class = AutoModelForTextEncoding
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
-    def __post_init__(self, **kwargs):
-        self.model = RBLNRuntimeModel(runtime=self.model[0])
-
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
         return T5EncoderWrapper(model)
 
     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        batch_size = rbln_config.get("batch_size", 1)
-        max_sequence_length = rbln_config.get("max_sequence_length", 256)
-        model_input_names = ["input_ids"]
-
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "max_seq_len": max_sequence_length,
-                "model_input_names": model_input_names,
-            }
-        )
-
-        return rbln_config
-
-    @classmethod
-    def _get_rbln_config(
+    def update_rbln_config_using_pipe(
         cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        max_position_embeddings = getattr(model_config, "n_positions", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        pipe: "RBLNDiffusionMixin",
+        rbln_config: "RBLNDiffusionMixinConfig",
+        submodule_name: str,
+    ) -> "RBLNDiffusionMixinConfig":
+        submodule_config = getattr(rbln_config, submodule_name)
+        submodule_config.max_seq_len = rbln_config.max_seq_len or 256
+        submodule_config.model_input_names = ["input_ids"]
         return rbln_config
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
-        encoder_outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        if not return_dict:
-            return (encoder_outputs,)
-        else:
-            return BaseModelOutput(last_hidden_state=encoder_outputs)
-
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
-    def __post_init__(self, **kwargs):
-        batch_size = self.rbln_config.model_cfg["batch_size"]
-        dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-
-        self.encoder = RBLNRuntimeEncoder(
-            runtime=self.model[0],
-            main_input_name="input_ids",
-        )
-        self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1],
-            main_input_name="input_ids",
-            batch_size=batch_size,
-            dec_max_seq_len=dec_max_seq_len,
-        )
+    support_causal_attn = False
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
-        dec_max_seq_len = rbln_config.model_cfg["dec_max_seq_len"]
-
-        return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
+        return T5Wrapper(
+            model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
+        )
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
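update_rbln_config_using_pipe shows the same dict-to-typed-config migration at the pipeline level: the old hook mutated a plain Dict[str, Any], while the new one receives the whole RBLNDiffusionMixinConfig plus a submodule name and mutates the matching sub-config attribute. A minimal runnable stand-in with hypothetical config classes mirroring the logic above:

    from typing import List, Optional

    class TextEncoderConfig:  # hypothetical stand-in for the T5 encoder sub-config
        max_seq_len: Optional[int] = None
        model_input_names: Optional[List[str]] = None

    class PipelineConfig:  # hypothetical stand-in for RBLNDiffusionMixinConfig
        def __init__(self, max_seq_len: Optional[int] = None):
            self.max_seq_len = max_seq_len
            self.text_encoder = TextEncoderConfig()

    def update_rbln_config_using_pipe(pipe, rbln_config, submodule_name):
        # Fetch the typed sub-config by name and set attributes
        # instead of updating a shared dict.
        submodule_config = getattr(rbln_config, submodule_name)
        submodule_config.max_seq_len = rbln_config.max_seq_len or 256
        submodule_config.model_input_names = ["input_ids"]
        return rbln_config

    cfg = update_rbln_config_using_pipe(None, PipelineConfig(), "text_encoder")
    assert cfg.text_encoder.max_seq_len == 256  # falls back to the 256 default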
@@ -279,139 +80,3 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
             return redirect(val)
 
         return val
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
-        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
-        d_kv = (
-            model_config.d_kv
-            if hasattr(model_config, "d_kv")
-            else model_config.d_model // model_config.encoder_attention_heads
-        )
-
-        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-            model_config, "max_position_embeddings", None
-        )
-
-        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-            if rbln_pad_token_id is None:
-                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                if rbln_pad_token_id is None:
-                    rbln_pad_token_id = -1
-
-        if rbln_enc_max_seq_len is None:
-            rbln_enc_max_seq_len = max_position_embeddings
-            if rbln_enc_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_enc_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_enc_max_seq_len is None:
-                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        if rbln_dec_max_seq_len is None:
-            rbln_dec_max_seq_len = max_position_embeddings
-            if rbln_dec_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_dec_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_dec_max_seq_len is None:
-                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        # model input info
-        enc_input_info = [
-            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-            (
-                "cross_key_value_states",
-                [
-                    n_layer * 2,
-                    rbln_batch_size,
-                    n_head,
-                    rbln_enc_max_seq_len,
-                    d_kv,
-                ],
-                "float32",
-            ),
-            ("block_tables", [1], "int16"),
-        ]
-
-        dec_input_info = [
-            ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
-            (
-                "cache_position",
-                [rbln_batch_size, 1],
-                "int32",
-            ),
-        ]
-        dec_input_info.extend(
-            [
-                (
-                    "cross_key_value_states",
-                    [
-                        n_layer * 2,
-                        rbln_batch_size,
-                        n_head,
-                        rbln_enc_max_seq_len,
-                        d_kv,
-                    ],
-                    "float32",
-                )
-            ]
-        )
-        dec_input_info.extend(
-            [
-                (
-                    f"self_key_value_states_{i}",
-                    [
-                        rbln_batch_size,
-                        n_head,
-                        rbln_dec_max_seq_len,
-                        d_kv,
-                    ],
-                    "float32",
-                )
-                for i in range(n_layer * 2)
-            ]
-        )
-
-        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
-        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "enc_max_seq_len": rbln_enc_max_seq_len,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "pad_token_id": rbln_pad_token_id,
-            }
-        )
-
-        return rbln_config
optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -18,7 +18,6 @@ import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_add_softmax_attention
 from ..seq2seq.seq2seq_architecture import (
     Seq2SeqDecoder,
     Seq2SeqDecoderLayer,
@@ -55,7 +54,6 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):
 
 class T5DecoderWrapper(Seq2SeqDecoderWrapper):
     def __post_init__(self, model, dec_max_seq_len: int = None):
-        register_rbln_custom_add_softmax_attention()
         self.num_layers = self.config.num_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
 
@@ -77,11 +75,13 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
         attention_mask,
         encoder_attention_mask,
         cache_position,
-        cross_kv_cache,
-        *self_kv_cache,
+        block_tables,
+        *kv_cache,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
         self_past_key_values = ()
         cross_past_key_values = ()
+        self_kv_cache = kv_cache[self.num_layers * 2 :]
+        cross_kv_cache = kv_cache[: self.num_layers * 2]
 
         for i in range(0, self.num_layers * 2, 2):
             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
@@ -95,6 +95,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
 
         return lm_logits
@@ -162,7 +163,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         self.out_proj = self._original_mod.o
         self.num_heads = self._original_mod.n_heads
         self.head_dim = self._original_mod.key_value_proj_dim
-        self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
+        self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
@@ -176,6 +177,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -185,6 +187,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)
 
+        block_size = past_key_value[0].shape[-2]
         attn_output = self.attn_decode(
             query_states,
             key_states,
@@ -196,6 +199,8 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
             cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
+            block_tables,
+            block_size,
         )
 
         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
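The self-attention decode now calls the paged variant of the custom op: the K/V caches are viewed as [bsz, num_heads, 1, seq, head_dim] (one block per sequence, so block_size equals the cache length), and block_tables maps each batch entry to its cache block. A plain-PyTorch sketch of the block lookup the paged op implies; the real kernel fuses this gather with the masked-softmax attention, and the shapes below are illustrative:

    import torch

    bsz, num_heads, head_dim, dec_max_seq_len = 2, 8, 64, 32

    # One cache block per sequence: block_size == dec_max_seq_len here.
    key_cache = torch.randn(bsz, num_heads, 1, dec_max_seq_len, head_dim)
    block_size = key_cache.shape[-2]                     # as in the diff
    block_tables = torch.arange(bsz, dtype=torch.int16)  # identity mapping

    # Resolve each sequence's block before the softmax(QK^T + mask)V step.
    blocks = key_cache.reshape(-1, num_heads, block_size, head_dim)
    per_seq_keys = blocks[block_tables.long()]
    assert per_seq_keys.shape == (bsz, num_heads, block_size, head_dim)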
optimum/rbln/transformers/models/time_series_transformers/__init__.py (new file)

@@ -0,0 +1,26 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
+from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
+from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py (new file)

@@ -0,0 +1,34 @@
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        enc_max_seq_len: Optional[int] = None,
+        dec_max_seq_len: Optional[int] = None,
+        num_parallel_samples: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+            num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.enc_max_seq_len = enc_max_seq_len
+        self.dec_max_seq_len = dec_max_seq_len
+        self.num_parallel_samples = num_parallel_samples
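A hedged usage sketch for the new config class; the checkpoint ID mirrors the public Hugging Face time-series example and the loading pattern follows the other RBLN model classes, neither taken from this diff:

    from optimum.rbln import (
        RBLNTimeSeriesTransformerForPrediction,
        RBLNTimeSeriesTransformerForPredictionConfig,
    )

    # batch_size must be a positive int; anything else raises ValueError.
    rbln_config = RBLNTimeSeriesTransformerForPredictionConfig(
        batch_size=1,
        num_parallel_samples=100,
    )
    model = RBLNTimeSeriesTransformerForPrediction.from_pretrained(
        "huggingface/time-series-transformer-tourism-monthly",
        export=True,
        rbln_config=rbln_config,
    )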