optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +11 -86
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -118
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +23 -151
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post1.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
+++ b/optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -18,12 +18,6 @@ import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import (
-    register_rbln_custom_cache_update,
-    register_rbln_custom_paged_attention,
-    register_rbln_custom_paged_causal_attention,
-)
-
 
 logger = logging.get_logger(__name__)
 
@@ -59,7 +53,6 @@ class Seq2SeqEncoderWrapper(nn.Module):
 
     def __init__(self, model: nn.Module, enc_max_seq_len: int):
         super().__init__()
-        register_rbln_custom_cache_update()
         self.config = model.config
         self.encoder = model.get_encoder()
         self.encoder_max_length = enc_max_seq_len
@@ -90,8 +83,8 @@ class Seq2SeqEncoderWrapper(nn.Module):
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
-        cross_key_values: torch.Tensor,
         b_idx: torch.Tensor,
+        *cross_key_values: Tuple[torch.Tensor],
     ) -> Tuple[torch.Tensor]:
         # 1. get encoder last_hidden_states
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -110,13 +103,15 @@ class Seq2SeqEncoderWrapper(nn.Module):
             cross_kv.append(past_k)
             cross_kv.append(past_v)
 
-        cross_kv = torch.stack(cross_kv, dim=0)
-
         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
-        batch_axis = torch.tensor(1, dtype=torch.int16)
-        enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
+        batch_axis = torch.tensor(0, dtype=torch.int16)
+        cross_key_values = list(cross_key_values)
+        for i in range(self.n_layer * 2):
+            cross_key_values[i] = torch.ops.rbln_custom_ops.rbln_cache_update(
+                cross_key_values[i], cross_kv[i], b_idx[0], batch_axis
+            )
 
-        return enc_out
+        return cross_key_values
 
 
 class Seq2SeqDecoderWrapper(nn.Module):
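This hunk changes the cache-write pattern: instead of stacking every cross-attention K/V tensor and issuing one rbln_cache_update along batch axis 1, each layer's K and V cache is now updated individually along batch axis 0. Below is a minimal pure-PyTorch sketch of the copy semantics the call appears to assume; the real op is an RBLN custom kernel that writes device DRAM, and `cache_update_reference` with its toy shapes is an illustrative assumption, not the library's implementation.

```python
import torch

def cache_update_reference(cache: torch.Tensor, state: torch.Tensor,
                           b_idx: torch.Tensor, batch_axis: torch.Tensor) -> torch.Tensor:
    # Assumed semantics: copy `state` into `cache` at offset `b_idx`
    # along `batch_axis`, returning the updated cache tensor.
    axis, start = int(batch_axis), int(b_idx)
    index = torch.arange(start, start + state.shape[axis])
    return cache.index_copy(axis, index, state)

# One cross-attention K cache for a 2-request batch: (batch, heads, enc_len, head_dim)
cache = torch.zeros(2, 8, 128, 64)
state = torch.randn(1, 8, 128, 64)  # freshly computed K for request 0
cache = cache_update_reference(cache, state, torch.tensor(0), torch.tensor(0, dtype=torch.int16))
```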
@@ -146,11 +141,6 @@ class Seq2SeqDecoderWrapper(nn.Module):
         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
         by subclasses to modify or add custom attributes as necessary.
         """
-        if self.use_attention_mask:
-            register_rbln_custom_paged_attention()
-        else:
-            register_rbln_custom_paged_causal_attention()
-
         self.num_layers = self.config.decoder_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
 
@@ -176,16 +166,17 @@ class Seq2SeqDecoderWrapper(nn.Module):
                 encoder_attention_mask,
                 cache_position,
                 block_tables,
-                cross_kv_cache,
-                *self_kv_cache,
+                *kv_cache,
             ) = args
 
         else:
             attention_mask = None
-            (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+            (input_ids, encoder_attention_mask, cache_position, block_tables, *kv_cache) = args
 
         self_past_key_values = ()
         cross_past_key_values = ()
+        self_kv_cache = kv_cache[self.num_layers * 2 :]
+        cross_kv_cache = kv_cache[: self.num_layers * 2]
         for i in range(0, self.num_layers * 2, 2):
             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
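The decoder wrapper now receives one flattened variadic `*kv_cache` instead of separate `cross_kv_cache` and `*self_kv_cache` arguments, so the layout convention carries the meaning: the first `num_layers * 2` tensors are the cross-attention caches and the remainder are the self-attention caches, each interleaved as (K, V) per layer. A small standalone sketch of the split-and-pair step shown in the hunk above, with arbitrary toy shapes:

```python
import torch

num_layers = 2
# Flattened layout assumed by the new forward(): cross K/V for every layer
# first, then self K/V for every layer, interleaved as (K0, V0, K1, V1, ...).
kv_cache = tuple(torch.empty(1, 8, 16, 64) for _ in range(num_layers * 4))

cross_kv_cache = kv_cache[: num_layers * 2]
self_kv_cache = kv_cache[num_layers * 2 :]

self_past_key_values = tuple(
    (self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)
cross_past_key_values = tuple(
    (cross_kv_cache[i], cross_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)
assert len(self_past_key_values) == num_layers  # one (K, V) pair per layer
```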
--- a/optimum/rbln/transformers/models/t5/__init__.py
+++ b/optimum/rbln/transformers/models/t5/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ....ops import paged_add_softmax_attn_decode
+from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
 from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
--- /dev/null
+++ b/optimum/rbln/transformers/models/t5/configuration_t5.py
@@ -0,0 +1,24 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
+
+
+class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+    pass
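Both new classes are pure aliases: they inherit every field from the generic encoder and seq2seq configs. A hedged sketch of how they would be used at export time follows; the `enc_max_seq_len`/`dec_max_seq_len` names come from this diff, while the `from_pretrained(..., export=True, rbln_config=...)` call shape is an assumption about the reworked 0.7.4 configuration API rather than a quote from it.

```python
from optimum.rbln import RBLNT5ForConditionalGeneration, RBLNT5ForConditionalGenerationConfig

# Field names come straight from this diff; the export call shape is assumed.
rbln_config = RBLNT5ForConditionalGenerationConfig(
    batch_size=1,
    enc_max_seq_len=512,
    dec_max_seq_len=256,
)
model = RBLNT5ForConditionalGeneration.from_pretrained(
    "google-t5/t5-small",
    export=True,
    rbln_config=rbln_config,
)
```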
--- a/optimum/rbln/transformers/models/t5/modeling_t5.py
+++ b/optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -13,48 +13,21 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable
 
 import torch
-from transformers import (
-    AutoModelForTextEncoding,
-    PretrainedConfig,
-    T5EncoderModel,
-    T5ForConditionalGeneration,
-)
-from transformers.modeling_outputs import BaseModelOutput
-
-from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....utils.logging import get_logger
-from ....utils.runtime_utils import RBLNPytorchRuntime
+from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
+
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
 from .t5_architecture import T5Wrapper
 
 
-logger = get_logger()
-
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
-
-
-class RBLNRuntimeModel(RBLNPytorchRuntime):
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: torch.FloatTensor,
-        head_mask: torch.FloatTensor,
-        inputs_embeds: torch.FloatTensor,
-        **kwargs,
-    ):
-        return super().forward(
-            input_ids,
-            attention_mask,
-            head_mask,
-            inputs_embeds,
-            **kwargs,
-        )
+    from transformers import PreTrainedModel
+
+    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 class T5EncoderWrapper(torch.nn.Module):
@@ -67,136 +40,35 @@ class T5EncoderWrapper(torch.nn.Module):
         return self.model(*args, **kwargs, return_dict=False)
 
 
-class RBLNT5EncoderModel(RBLNModel):
+class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
     auto_model_class = AutoModelForTextEncoding
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
-    def __post_init__(self, **kwargs):
-        self.model = RBLNRuntimeModel(runtime=self.model[0])
-
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
         return T5EncoderWrapper(model)
 
     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        batch_size = rbln_config.get("batch_size", 1)
-        max_sequence_length = rbln_config.get("max_sequence_length", 256)
-        model_input_names = ["input_ids"]
-
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "max_seq_len": max_sequence_length,
-                "model_input_names": model_input_names,
-            }
-        )
-
-        return rbln_config
-
-    @classmethod
-    def _get_rbln_config(
+    def update_rbln_config_using_pipe(
         cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        max_position_embeddings = getattr(model_config, "n_positions", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        pipe: "RBLNDiffusionMixin",
+        rbln_config: "RBLNDiffusionMixinConfig",
+        submodule_name: str,
+    ) -> "RBLNDiffusionMixinConfig":
+        submodule_config = getattr(rbln_config, submodule_name)
+        submodule_config.max_seq_len = rbln_config.max_seq_len or 256
+        submodule_config.model_input_names = ["input_ids"]
        return rbln_config
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
-        encoder_outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            inputs_embeds=inputs_embeds,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        if not return_dict:
-            return (encoder_outputs,)
-        else:
-            return BaseModelOutput(last_hidden_state=encoder_outputs)
-
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
-    support_causal_paged_attn = False
+    support_causal_attn = False
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
-        dec_max_seq_len = rbln_config.model_cfg["dec_max_seq_len"]
-
-        return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
+        return T5Wrapper(
+            model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
+        )
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
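The rewritten update_rbln_config_using_pipe no longer merges defaults into a plain dict; it receives the typed pipeline-level config, fetches its own submodule's config object by attribute name, and mutates that in place. A toy reduction of the pattern using stand-in dataclasses (TextEncoderConfig and PipelineConfig here are hypothetical stand-ins, not the library's classes):

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class TextEncoderConfig:            # stand-in for RBLNT5EncoderModelConfig
    max_seq_len: Optional[int] = None
    model_input_names: List[str] = field(default_factory=list)

@dataclass
class PipelineConfig:               # stand-in for RBLNDiffusionMixinConfig
    max_seq_len: Optional[int] = None
    text_encoder: TextEncoderConfig = field(default_factory=TextEncoderConfig)

def update_using_pipe(rbln_config: PipelineConfig, submodule_name: str) -> PipelineConfig:
    submodule_config = getattr(rbln_config, submodule_name)  # e.g. "text_encoder"
    submodule_config.max_seq_len = rbln_config.max_seq_len or 256
    submodule_config.model_input_names = ["input_ids"]
    return rbln_config

cfg = update_using_pipe(PipelineConfig(max_seq_len=None), "text_encoder")
assert cfg.text_encoder.max_seq_len == 256  # pipeline default applied
```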
--- a/optimum/rbln/transformers/models/t5/t5_architecture.py
+++ b/optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -18,7 +18,6 @@ import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import register_rbln_custom_add_softmax_attention
 from ..seq2seq.seq2seq_architecture import (
     Seq2SeqDecoder,
     Seq2SeqDecoderLayer,
@@ -55,7 +54,6 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):
 
 class T5DecoderWrapper(Seq2SeqDecoderWrapper):
     def __post_init__(self, model, dec_max_seq_len: int = None):
-        register_rbln_custom_add_softmax_attention()
         self.num_layers = self.config.num_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
 
@@ -77,11 +75,13 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
         attention_mask,
         encoder_attention_mask,
         cache_position,
-        cross_kv_cache,
-        *self_kv_cache,
+        block_tables,
+        *kv_cache,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
         self_past_key_values = ()
         cross_past_key_values = ()
+        self_kv_cache = kv_cache[self.num_layers * 2 :]
+        cross_kv_cache = kv_cache[: self.num_layers * 2]
 
         for i in range(0, self.num_layers * 2, 2):
             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
@@ -95,6 +95,7 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
 
         return lm_logits
@@ -162,7 +163,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         self.out_proj = self._original_mod.o
         self.num_heads = self._original_mod.n_heads
         self.head_dim = self._original_mod.key_value_proj_dim
-        self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
+        self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
@@ -176,6 +177,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
@@ -185,6 +187,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)
 
+        block_size = past_key_value[0].shape[-2]
         attn_output = self.attn_decode(
             query_states,
             key_states,
@@ -196,6 +199,8 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
             cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
+            block_tables,
+            block_size,
         )
 
         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
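T5's self-attention decode now routes through the paged variant of the custom op, passing `block_tables` plus a `block_size` read off the cache shape (`past_key_value[0].shape[-2]`). As a rough reference for the non-paged core of an add-softmax attention decode, assuming the op computes softmax(QK^T * scale + additive bias) @ V; the paging itself (the block_tables indirection into a pooled cache) is deliberately omitted, and this sketch is not the RBLN kernel:

```python
import torch

def add_softmax_attn_decode_reference(q, k_cache, v_cache, mask, scale):
    # q: (b, h, 1, d); caches: (b, h, s, d); mask: additive bias (b, 1, 1, s).
    # Assumed semantics: softmax((q @ k^T) * scale + mask) @ v.
    logits = torch.matmul(q, k_cache.transpose(-1, -2)) * scale + mask
    return torch.matmul(torch.softmax(logits, dim=-1), v_cache)

b, h, s, d = 1, 8, 32, 64
q = torch.randn(b, h, 1, d)
k, v = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
mask = torch.zeros(b, 1, 1, s)  # T5-style position bias / validity mask
out = add_softmax_attn_decode_reference(q, k, v, mask, scale=1.0)  # scale matches the diff
assert out.shape == (b, h, 1, d)
```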
--- /dev/null
+++ b/optimum/rbln/transformers/models/time_series_transformers/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
+from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
+from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
--- /dev/null
+++ b/optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py
@@ -0,0 +1,34 @@
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        enc_max_seq_len: Optional[int] = None,
+        dec_max_seq_len: Optional[int] = None,
+        num_parallel_samples: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
+            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
+            num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.enc_max_seq_len = enc_max_seq_len
+        self.dec_max_seq_len = dec_max_seq_len
+        self.num_parallel_samples = num_parallel_samples
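A usage sketch grounded in the constructor above. Only the argument names and the batch_size validation are guaranteed by this diff; the `from_pretrained` entry point and the checkpoint id are illustrative assumptions:

```python
from optimum.rbln import (
    RBLNTimeSeriesTransformerForPrediction,
    RBLNTimeSeriesTransformerForPredictionConfig,
)

config = RBLNTimeSeriesTransformerForPredictionConfig(
    batch_size=1,               # validated: must be a positive int
    enc_max_seq_len=64,
    dec_max_seq_len=32,
    num_parallel_samples=100,
)
model = RBLNTimeSeriesTransformerForPrediction.from_pretrained(
    "huggingface/time-series-transformer-tourism-monthly",  # illustrative checkpoint
    export=True,
    rbln_config=config,
)
```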