optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .streamers import BatchTextIteratorStreamer
@@ -0,0 +1,122 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from typing import List, Optional
25
+
26
+ import torch
27
+ from transformers import AutoTokenizer, TextIteratorStreamer
28
+
29
+
30
class BatchTextIteratorStreamer(TextIteratorStreamer):
    """
    Batched variant of `TextIteratorStreamer`.

    Decoded text is pushed onto an internal queue so that a downstream consumer (for example an interactive Gradio
    demo) can iterate over it without blocking generation. Unlike the parent class, every queue item is a list with
    one string per batch row:

    - `put` appends the newly printable text of each row.
    - `end` flushes whatever text is still buffered for each row and then enqueues the stop signal.

    Parameters:
        batch_size (int):
            Number of sequences generated in parallel; a separate token cache is kept for each of them.
        tokenizer (AutoTokenizer):
            Tokenizer used to turn token ids back into text.
        skip_prompt (bool, optional, default=False):
            If True, a `put` call made while `next_tokens_are_prompt` is set is dropped (the flag is re-armed by
            `end`), so the prompt passed to `.generate()` is not streamed. Useful, for example, for chatbots.
        timeout (float, optional):
            Timeout used for every queue `put`. If `None`, queue operations block indefinitely. Useful to surface
            exceptions from `.generate()` when it runs in a separate thread.
        **decode_kwargs (dict, optional):
            Extra keyword arguments forwarded to the tokenizer's `decode` method.
    """

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size: int = batch_size
        # Per-row buffers replace the parent's single token cache / printed-length counter.
        self.token_cache: List[List[int]] = [list() for _ in range(batch_size)]
        self.print_len = [0 for _ in range(batch_size)]

    def put(self, value):
        """
        Append newly generated token ids for every batch row and enqueue the text that is now safe to print.

        A tensor with fewer than two dimensions is treated as a flattened batch and reshaped into
        `batch_size` rows of equal length before processing.
        """
        if len(value.shape) < 2:
            tokens_per_row = value.shape[0] // self.batch_size
            value = torch.reshape(value, (self.batch_size, tokens_per_row))

        if self.skip_prompt and self.next_tokens_are_prompt:
            # The prompt is dropped once; subsequent calls carry generated tokens.
            self.next_tokens_are_prompt = False
            return

        self.on_finalized_text([self._printable_for_row(row, value[row]) for row in range(self.batch_size)])

    def _printable_for_row(self, row: int, new_tokens) -> str:
        """Extend row `row`'s cache with `new_tokens` and return the newly printable suffix of its decoded text."""
        cache = self.token_cache[row]
        cache.extend(new_tokens.tolist())
        decoded = self.tokenizer.decode(cache, **self.decode_kwargs)
        start = self.print_len[row]

        if decoded.endswith("\n"):
            # A newline completes the line: emit the remainder and reset this row's buffers.
            self.token_cache[row] = []
            self.print_len[row] = 0
            return decoded[start:]

        if decoded and self._is_chinese_char(ord(decoded[-1])):
            # A trailing CJK character stands on its own, so everything decoded so far can be emitted.
            chunk = decoded[start:]
        else:
            # Stop after the last space so a word the next token might still change is not emitted early
            # (a simple heuristic).
            chunk = decoded[start : decoded.rfind(" ") + 1]
        self.print_len[row] += len(chunk)
        return chunk

    def end(self):
        """Flush each row's remaining buffered text, re-arm prompt skipping, and enqueue the stop signal."""
        remaining = []
        for row in range(self.batch_size):
            if not self.token_cache[row]:
                remaining.append("")
                continue
            decoded = self.tokenizer.decode(self.token_cache[row], **self.decode_kwargs)
            remaining.append(decoded[self.print_len[row] :])
            self.token_cache[row] = []
            self.print_len[row] = 0

        self.next_tokens_are_prompt = True
        self.on_finalized_text(remaining, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        """Enqueue one list of per-row strings; when `stream_end` is set, also enqueue the stop signal."""
        self.text_queue.put(texts, timeout=self.timeout)
        if not stream_end:
            return
        self.text_queue.put(self.stop_signal, timeout=self.timeout)
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
25
+ from .gpt2 import RBLNGPT2LMHeadModel
26
+ from .llama import RBLNLlamaForCausalLM
27
+ from .wav2vec2 import RBLNWav2Vec2ForCTC
28
+ from .whisper import RBLNWhisperForConditionalGeneration
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .bart_architecture import BartDecoderWrapper, BartEncoderWrapper
@@ -0,0 +1,377 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from typing import Optional, Tuple
25
+
26
+ import torch
27
+ from torch import nn
28
+ from transformers.modeling_attn_mask_utils import (
29
+ _prepare_4d_attention_mask,
30
+ _prepare_4d_attention_mask_for_sdpa,
31
+ _prepare_4d_causal_attention_mask,
32
+ _prepare_4d_causal_attention_mask_for_sdpa,
33
+ )
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPastAndCrossAttentions,
36
+ )
37
+ from transformers.models.bart.modeling_bart import (
38
+ BartAttention,
39
+ BartDecoder,
40
+ BartDecoderLayer,
41
+ BartForConditionalGeneration,
42
+ BartSdpaAttention,
43
+ )
44
+ from transformers.utils import logging
45
+
46
+
47
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
48
+
49
+
50
class _BartAttention(BartAttention):
    """Eager BART attention adapted to a fixed-size (statically shaped) key/value cache."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Run eager attention and return `(attn_output, (key_states, value_states))`.

        Args:
            hidden_states: Query-side input, unpacked as `(bsz, tgt_len, embed_dim)`.
            past_key_value: `(key, value)` cache pair.
                - Self-attention: the current step's projections are written into it at `cache_position`
                  along dim 2.
                - Cross-attention: it is reused as-is unless `key_value_states` is a real (rank > 1) tensor.
            attention_mask: Additive mask, added to scores viewed as `(bsz, num_heads, tgt_len, src_len)`.
            cache_position: Position along dim 2 where the current step is written (self-attention only).
            key_value_states: When not None the call is cross-attention.
                - A rank > 1 tensor (encoder hidden states) is projected to fresh keys/values.
                - A rank <= 1 placeholder selects the cached pair from `past_key_value`.

        Returns:
            The projected attention output, and the present key/value pair. The pair keeps the per-head layout
            produced by `self._shape` (`(bsz, num_heads, seq_len, head_dim)`, the layout indexed on dim 2
            above). It can therefore be fed back as `past_key_value` and matches the SDPA implementation's output.
        """
        bsz, tgt_len, _ = hidden_states.size()
        is_cross_attention = key_value_states is not None

        query_states = self.q_proj(hidden_states) * self.scaling

        if is_cross_attention:
            # Rank > 1: real encoder states were passed (encoder-side "dummy decoder" pass that precomputes the
            # cross-attention cache). Otherwise a placeholder was passed and the cached projections are reused.
            is_dummy_decoder = len(key_value_states.shape) > 1
            if is_dummy_decoder:
                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            else:
                key_states = past_key_value[0]
                value_states = past_key_value[1]
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            # Write this step into the static cache at `cache_position` (sequence axis = dim 2).
            key_states = past_key_value[0].slice_scatter(
                key_states, dim=2, start=cache_position, end=cache_position + 1
            )
            value_states = past_key_value[1].slice_scatter(
                value_states, dim=2, start=cache_position, end=cache_position + 1
            )

        # Capture the cache BEFORE heads are folded into the batch axis below. Previously the flattened
        # (bsz * num_heads, seq, head_dim) tensors were returned, which no longer matched the incoming cache
        # layout (or the SDPA path) and broke the dim-2 slice_scatter when fed back on the next step.
        present_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.bmm(attn_weights, value_states)
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, present_key_value
+
104
+
105
class _BartSdpaAttention(BartSdpaAttention):
    """BART attention backed by `scaled_dot_product_attention`, operating on a fixed-size key/value cache."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Compute attention for one step and return `(attn_output, (key, value))`.

        Three modes are selected by `key_value_states`:

        - None: self-attention. The current step's key/value projections are scattered into
          `past_key_value` at `cache_position` along dim 2.
        - Rank > 1 tensor: cross-attention with real encoder states. Fresh keys/values are projected from it.
        - Anything else (a placeholder): cross-attention that reuses the cached `past_key_value` pair unchanged.

        The returned `(key, value)` pair is the cache for the next step.
        """
        batch, seq_len, _ = hidden_states.size()
        query = self.q_proj(hidden_states)

        if key_value_states is None:
            # Self-attention: project this step and write it into the static cache along dim 2.
            step_end = cache_position + 1
            key = past_key_value[0].slice_scatter(
                self._shape(self.k_proj(hidden_states), -1, batch), dim=2, start=cache_position, end=step_end
            )
            value = past_key_value[1].slice_scatter(
                self._shape(self.v_proj(hidden_states), -1, batch), dim=2, start=cache_position, end=step_end
            )
        elif len(key_value_states.shape) > 1:
            # Cross-attention fed with real encoder states (dummy-decoder pass): project fresh keys/values.
            key = self._shape(self.k_proj(key_value_states), -1, batch)
            value = self._shape(self.v_proj(key_value_states), -1, batch)
        else:
            # Cross-attention fed with a placeholder: reuse the precomputed cache.
            key, value = past_key_value[0], past_key_value[1]

        # scaled_dot_product_attention applies its default scaling, so the query is not pre-scaled here.
        context = torch.nn.functional.scaled_dot_product_attention(
            self._shape(query, seq_len, batch),
            key,
            value,
            attn_mask=attention_mask,
        )
        context = context.transpose(1, 2).reshape(batch, seq_len, self.embed_dim)
        return self.out_proj(context), (key, value)
+
154
+
155
# Maps an attention implementation name (the `attn_impl` argument, which the wrappers take from
# `config._attn_implementation`) to the unbound `forward` that `_BartDecoderLayer` calls.
ATTN_FORWARD_MAP = dict(
    eager=_BartAttention.forward,
    sdpa=_BartSdpaAttention.forward,
)
156
+
157
+
158
class _BartDecoderLayer(BartDecoderLayer):
    """BART decoder layer that threads a flat 4-tuple KV cache through self- and cross-attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        cache_position: torch.Tensor,
        attn_impl: str = "eager",
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Apply the self-attention, cross-attention and feed-forward sub-blocks.

        Each sub-block adds its output to its input and then applies that sub-block's LayerNorm.

        Args:
            past_key_value: `(self_key, self_value, cross_key, cross_value)`. The first two entries feed
                self-attention; the last two feed cross-attention.
            attn_impl: Key into `ATTN_FORWARD_MAP` ("eager" or "sdpa").

        Returns:
            `(hidden_states, present_key_value)`, where `present_key_value` concatenates the self- and
            cross-attention present pairs into one 4-tuple.
        """
        attend = ATTN_FORWARD_MAP[attn_impl]

        # Self-attention sub-block.
        self_out, self_present = attend(
            self.self_attn,
            hidden_states=hidden_states,
            past_key_value=past_key_value[:2],
            attention_mask=attention_mask,
            cache_position=cache_position,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + self_out)

        # Cross-attention sub-block.
        cross_out, cross_present = attend(
            self.encoder_attn,
            hidden_states=hidden_states,
            key_value_states=encoder_hidden_states,
            past_key_value=past_key_value[-2:],
            attention_mask=encoder_attention_mask,
            cache_position=cache_position,
        )
        hidden_states = self.encoder_attn_layer_norm(hidden_states + cross_out)

        # Feed-forward sub-block.
        ffn_out = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        hidden_states = self.final_layer_norm(hidden_states + ffn_out)

        return hidden_states, self_present + cross_present
208
+
209
+
210
class _BartDecoder(BartDecoder):
    """BART decoder forward pass that consumes per-layer static KV caches indexed by `cache_position`."""

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        past_key_values: torch.Tensor,
        cache_position: torch.Tensor,
        attn_impl: str = "eager",
    ):
        """Embed `input_ids`, build 4-D masks, and run every decoder layer with its cache entry.

        Args:
            past_key_values: Indexable per decoder layer; entry `i` is the 4-tuple consumed by layer `i`.
            cache_position: Current decoding position. It indexes the learned position table (after adding its
                offset) and is forwarded to the attention layers.
            attn_impl: Attention implementation name forwarded to each layer.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` whose `past_key_values` is a tuple with one present cache
            per layer.
        """
        # Learned positional embeddings are read straight from the table, shifted by its offset.
        position_embeds = self.embed_positions.weight[cache_position + self.embed_positions.offset]
        token_embeds = self.embed_tokens(input_ids) * self.embed_scale
        hidden_states = self.layernorm_embedding(token_embeds + position_embeds)

        # Pick the mask helpers that match the attention backend.
        if self._use_sdpa:
            make_self_mask = _prepare_4d_causal_attention_mask_for_sdpa
            make_cross_mask = _prepare_4d_attention_mask_for_sdpa
        else:
            make_self_mask = _prepare_4d_causal_attention_mask
            make_cross_mask = _prepare_4d_attention_mask

        input_shape = input_ids.size()
        # NOTE(review): `cache_position` is passed positionally in the slot transformers names
        # past_key_values_length; verify against the installed transformers version.
        attention_mask = make_self_mask(attention_mask, input_shape, token_embeds, cache_position)
        encoder_attention_mask = make_cross_mask(
            encoder_attention_mask, token_embeds.dtype, tgt_len=input_shape[-1]
        )

        present_caches = []
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, layer_present = _BartDecoderLayer.forward(
                layer,
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_values[layer_idx],
                cache_position=cache_position,
                attn_impl=attn_impl,
            )
            present_caches.append(layer_present)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=tuple(present_caches),
        )
268
+
269
+
270
class BartDecoderWrapper(torch.nn.Module):
    """Single-step BART decoder mapping `(tokens, masks, position, caches)` to `(logits, updated self-attn cache)`."""

    def __init__(self, model: "BartForConditionalGeneration"):
        super().__init__()
        self.config = model.config
        self.decoder = model.get_decoder()
        self.num_layers = self.config.decoder_layers
        self.lm_head = model.lm_head

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        self_kv_cache: torch.Tensor,
        cross_kv_cache: torch.Tensor,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
        """Run one decoding step.

        Args:
            self_kv_cache: Indexable along dim 0 with `2 * num_layers` entries, interleaved as
                `[layer0_key, layer0_value, layer1_key, ...]`.
            cross_kv_cache: Same interleaved layout as `self_kv_cache`, holding the precomputed
                cross-attention keys/values.

        Returns:
            `(lm_logits, self_kv_cache)`. The second element is a single tensor stacking the updated
            self-attention keys/values along a new dim 0 in the same interleaved order.
        """
        # Regroup the flat caches into one (self_k, self_v, cross_k, cross_v) tuple per layer.
        kv_cache = tuple(
            (
                self_kv_cache[2 * layer],
                self_kv_cache[2 * layer + 1],
                cross_kv_cache[2 * layer],
                cross_kv_cache[2 * layer + 1],
            )
            for layer in range(self.num_layers)
        )

        # A rank-1 placeholder as `encoder_hidden_states` makes cross-attention reuse `cross_kv_cache`.
        decoder_outputs = _BartDecoder.forward(
            self.decoder,
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            cache_position=cache_position,
            past_key_values=kv_cache,
            encoder_hidden_states=torch.tensor([1]),
            attn_impl=self.config._attn_implementation,
        )
        lm_logits = self.lm_head(decoder_outputs[0])

        # Flatten each layer's updated self-attention (key, value) back into the interleaved stacked layout.
        presents = decoder_outputs[1]
        updated_self_kv = torch.stack(
            [tensor for layer in range(self.num_layers) for tensor in presents[layer][:2]],
            dim=0,
        )
        return lm_logits, updated_self_kv
323
+
324
+
325
class BartEncoderWrapper(torch.nn.Module):
    """Run the BART encoder plus a one-token "dummy" decoder pass that precomputes cross-attention caches.

    The returned tensor is meant to be bound to the decoder wrapper's `cross_kv_cache` input.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config
        self.decoder = model.get_decoder()
        self.encoder = model.get_encoder()
        # `num_layers` is left as the encoder depth for any caller that reads it.
        self.num_layers = self.config.encoder_layers
        # The dummy cache built in `forward` is indexed once per *decoder* layer, so it must be sized from
        # `decoder_layers`. Encoder and decoder depths are independent config values and may differ.
        self.num_decoder_layers = self.config.decoder_layers
        self.decoder_max_length = self.config.max_position_embeddings
        self.encoder_max_length = self.config.max_position_embeddings
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> Tuple[torch.Tensor]:
        """Encode `input_ids` and return the cross-attention keys/values of every decoder layer.

        Args:
            input_ids: Encoder token ids; dim 0 is the batch size.
            attention_mask: Encoder padding mask. It is also used as the decoder's cross-attention mask.

        Returns:
            A tensor that stacks each decoder layer's `(cross_key, cross_value)` pair along dim 1, and the
            layers along dim 0.
        """
        encoder_batch_size = input_ids.shape[0]
        decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search

        # 1. Run the encoder.
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_states = encoder_outputs[0]

        # 2. Run a one-token dummy decoder step. Passing the real encoder states makes each cross-attention
        #    layer project fresh keys/values, which are the values to cache for generation.
        #    One zero-filled cache entry is needed per decoder layer.
        self_cache_shape = (decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
        cross_cache_shape = (encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
        dummy_past_key_value = [
            (
                torch.zeros(self_cache_shape),
                torch.zeros(self_cache_shape),
                torch.zeros(cross_cache_shape),
                torch.zeros(cross_cache_shape),
            )
            for _ in range(self.num_decoder_layers)
        ]

        # Only the first decoder position is marked as valid.
        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.int64)
        decoder_attention_mask[:, :1] = 1

        decoder_outputs = _BartDecoder.forward(
            self.decoder,
            input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
            attention_mask=decoder_attention_mask,
            encoder_attention_mask=attention_mask,
            cache_position=torch.tensor(0, dtype=torch.int32),
            encoder_hidden_states=last_hidden_states,
            past_key_values=dummy_past_key_value,
            attn_impl=self.config._attn_implementation,
        )
        first_past_kv = decoder_outputs[1]

        # 3. Keep each layer's cross-attention (key, value), i.e. entries 2 and 3 of its cache tuple.
        #    This output feeds the decoder's recurrence port (enc_ir.outputs[0] -> dec_ir.inputs[5]).
        encoder_kv = torch.stack([torch.stack(layer_out[2:], dim=0) for layer_out in first_past_kv], dim=0)
        return encoder_kv
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection