optimum-rbln 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .streamers import BatchTextIteratorStreamer
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from typing import List, Optional
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from transformers import AutoTokenizer, TextIteratorStreamer
|
28
|
+
|
29
|
+
|
30
|
+
class BatchTextIteratorStreamer(TextIteratorStreamer):
    """
    Batched counterpart of `TextIteratorStreamer`.

    Instead of a single string, every item pushed to the text queue is a list holding one text fragment per
    batch row. A downstream consumer can iterate over the streamer to receive those lists without blocking the
    generation loop (e.g., in an interactive Gradio demo).

    Parameters:
        batch_size (int):
            Number of sequences generated in parallel; one token cache is kept per sequence.
        tokenizer (AutoTokenizer):
            The tokenizer used to decode the tokens.
        skip_prompt (bool, optional, default=False):
            Whether to skip the prompt to `.generate()` or not. Useful, for example, for chatbots.
        timeout (float, optional):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()` when it is called in a separate thread.
        **decode_kwargs (dict, optional):
            Additional keyword arguments to pass to the tokenizer's `decode` method.
    """

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size: int = batch_size
        # Per-row state: the tokens not yet flushed, and how many decoded characters were already emitted.
        self.token_cache: List[List[int]] = [[] for _ in range(batch_size)]
        self.print_len = [0] * batch_size

    def put(self, value):
        """
        Accepts newly generated tokens for every batch row, decodes each row's cache and queues the text that
        is ready to be shown (one entry per row).
        """
        # A flat tensor is interpreted as `batch_size` rows of equal length.
        if len(value.shape) < 2:
            value = torch.reshape(value, (self.batch_size, value.shape[0] // self.batch_size))

        # The first call carries the prompt; drop it when requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        ready_texts = []
        for row in range(self.batch_size):
            # Grow this row's cache and decode it as a whole.
            self.token_cache[row].extend(value[row].tolist())
            decoded = self.tokenizer.decode(self.token_cache[row], **self.decode_kwargs)
            emitted = self.print_len[row]

            if decoded.endswith("\n"):
                # A trailing newline flushes everything and resets the row's cache.
                fragment = decoded[emitted:]
                self.token_cache[row] = []
                self.print_len[row] = 0
            else:
                if len(decoded) > 0 and self._is_chinese_char(ord(decoded[-1])):
                    # CJK characters are emitted right away.
                    fragment = decoded[emitted:]
                else:
                    # Otherwise emit up to the last space so that incomplete words are held back
                    # (they may still change with the next token).
                    fragment = decoded[emitted : decoded.rfind(" ") + 1]
                self.print_len[row] += len(fragment)

            ready_texts.append(fragment)

        self.on_finalized_text(ready_texts)

    def end(self):
        """Flushes whatever is left in every row's cache and signals the end of the stream."""
        ready_texts = []
        for row in range(self.batch_size):
            pending = self.token_cache[row]
            if len(pending) == 0:
                ready_texts.append("")
                continue

            decoded = self.tokenizer.decode(pending, **self.decode_kwargs)
            ready_texts.append(decoded[self.print_len[row] :])
            self.token_cache[row] = []
            self.print_len[row] = 0

        # The next `put` call belongs to a new generation and therefore starts with a prompt again.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(ready_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        """Queues the per-row texts and, when `stream_end` is set, the stop signal right after them."""
        self.text_queue.put(texts, timeout=self.timeout)
        if not stream_end:
            return
        self.text_queue.put(self.stop_signal, timeout=self.timeout)
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
25
|
+
from .gpt2 import RBLNGPT2LMHeadModel
|
26
|
+
from .llama import RBLNLlamaForCausalLM
|
27
|
+
from .wav2vec2 import RBLNWav2Vec2ForCTC
|
28
|
+
from .whisper import RBLNWhisperForConditionalGeneration
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .bart_architecture import BartDecoderWrapper, BartEncoderWrapper
|
@@ -0,0 +1,377 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from typing import Optional, Tuple
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from torch import nn
|
28
|
+
from transformers.modeling_attn_mask_utils import (
|
29
|
+
_prepare_4d_attention_mask,
|
30
|
+
_prepare_4d_attention_mask_for_sdpa,
|
31
|
+
_prepare_4d_causal_attention_mask,
|
32
|
+
_prepare_4d_causal_attention_mask_for_sdpa,
|
33
|
+
)
|
34
|
+
from transformers.modeling_outputs import (
|
35
|
+
BaseModelOutputWithPastAndCrossAttentions,
|
36
|
+
)
|
37
|
+
from transformers.models.bart.modeling_bart import (
|
38
|
+
BartAttention,
|
39
|
+
BartDecoder,
|
40
|
+
BartDecoderLayer,
|
41
|
+
BartForConditionalGeneration,
|
42
|
+
BartSdpaAttention,
|
43
|
+
)
|
44
|
+
from transformers.utils import logging
|
45
|
+
|
46
|
+
|
47
|
+
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)
|
48
|
+
|
49
|
+
|
50
|
+
class _BartAttention(BartAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """
        Eager (bmm + softmax) attention that works on an externally supplied key/value cache.

        Args:
            hidden_states: Query input; its first two dimensions are read as (batch, target length).
            past_key_value: Pair of cached key and value tensors.
            attention_mask: Additive mask applied to the attention scores.
            cache_position: Position along dim 2 of the cache where the new self-attention key/value is written.
            key_value_states: When given, the layer runs as cross-attention. A tensor with more than one
                dimension is projected into fresh key/value states; any other tensor is only a marker and the
                cached key/value from `past_key_value` is used instead.

        Returns:
            The projected attention output and the (key, value) pair used for the computation. Note that the
            returned key/value are the tensors after flattening to (batch * num_heads, -1, head_dim).
        """
        batch, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states) * self.scaling

        if key_value_states is None:
            # Self-attention: project the current step and write it into the cache at `cache_position`.
            step_key = self._shape(self.k_proj(hidden_states), -1, batch)
            step_value = self._shape(self.v_proj(hidden_states), -1, batch)
            key_states = past_key_value[0].slice_scatter(
                step_key, dim=2, start=cache_position, end=cache_position + 1
            )
            value_states = past_key_value[1].slice_scatter(
                step_value, dim=2, start=cache_position, end=cache_position + 1
            )
        elif len(key_value_states.shape) > 1:
            # Cross-attention in the dummy decoder pass: compute key/value from the encoder states.
            key_states = self._shape(self.k_proj(key_value_states), -1, batch)
            value_states = self._shape(self.v_proj(key_value_states), -1, batch)
        else:
            # Cross-attention during decoding: reuse the pre-computed key/value.
            key_states = past_key_value[0]
            value_states = past_key_value[1]

        # Fold the head dimension into the batch dimension for the batched matmuls.
        flat_shape = (batch * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, q_len, batch).view(*flat_shape)
        key_states = key_states.reshape(*flat_shape)
        value_states = value_states.reshape(*flat_shape)

        kv_len = key_states.size(1)
        scores = torch.bmm(query_states, key_states.transpose(1, 2))
        # The mask is added in the 4-d layout, then the scores are flattened again.
        scores = scores.view(batch, self.num_heads, q_len, kv_len) + attention_mask
        scores = scores.view(batch * self.num_heads, q_len, kv_len)
        probs = nn.functional.softmax(scores, dim=-1)

        context = torch.bmm(probs, value_states)
        context = context.view(batch, self.num_heads, q_len, self.head_dim)
        context = context.transpose(1, 2)
        context = context.reshape(batch, q_len, self.embed_dim)
        context = self.out_proj(context)

        return context, (key_states, value_states)
|
103
|
+
|
104
|
+
|
105
|
+
class _BartSdpaAttention(BartSdpaAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """
        Attention based on `scaled_dot_product_attention` that works on an externally supplied key/value cache.

        Args:
            hidden_states: Query input; its first two dimensions are read as (batch, target length).
            past_key_value: Pair of cached key and value tensors.
            attention_mask: Mask forwarded as `attn_mask` to `scaled_dot_product_attention`.
            cache_position: Position along dim 2 of the cache where the new self-attention key/value is written.
            key_value_states: When given, the layer runs as cross-attention. A tensor with more than one
                dimension is projected into fresh key/value states; any other tensor is only a marker and the
                cached key/value from `past_key_value` is used instead.

        Returns:
            The projected attention output and the (key, value) pair used for the computation.
        """
        batch, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)

        if key_value_states is None:
            # Self-attention: project the current step and write it into the cache at `cache_position`.
            step_key = self._shape(self.k_proj(hidden_states), -1, batch)
            step_value = self._shape(self.v_proj(hidden_states), -1, batch)
            key_states = past_key_value[0].slice_scatter(
                step_key, dim=2, start=cache_position, end=cache_position + 1
            )
            value_states = past_key_value[1].slice_scatter(
                step_value, dim=2, start=cache_position, end=cache_position + 1
            )
        elif len(key_value_states.shape) > 1:
            # Cross-attention in the dummy decoder pass: compute key/value from the encoder states.
            key_states = self._shape(self.k_proj(key_value_states), -1, batch)
            value_states = self._shape(self.v_proj(key_value_states), -1, batch)
        else:
            # Cross-attention during decoding: reuse the pre-computed key/value.
            key_states = past_key_value[0]
            value_states = past_key_value[1]

        query_states = self._shape(query_states, q_len, batch)

        context = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )
        context = context.transpose(1, 2).reshape(batch, q_len, self.embed_dim)
        context = self.out_proj(context)

        return context, (key_states, value_states)
|
153
|
+
|
154
|
+
|
155
|
+
# Maps an attention implementation name to the patched forward function used for it.
ATTN_FORWARD_MAP = {
    "eager": _BartAttention.forward,
    "sdpa": _BartSdpaAttention.forward,
}
|
156
|
+
|
157
|
+
|
158
|
+
class _BartDecoderLayer(BartDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        cache_position: torch.Tensor,
        attn_impl: str = "eager",
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """
        Runs self-attention, cross-attention and the feed-forward block, each followed by a residual
        connection and a layer norm.

        Args:
            hidden_states: Input of the layer.
            attention_mask: Mask for the self-attention.
            encoder_attention_mask: Mask for the cross-attention.
            encoder_hidden_states: Passed to the cross-attention as `key_value_states`.
            past_key_value: Cache of this layer; the first two entries are handed to the self-attention and
                the last two to the cross-attention.
            cache_position: Forwarded to both attention calls.
            attn_impl: Key into `ATTN_FORWARD_MAP` selecting the attention forward function.

        Returns:
            The layer output and the concatenation of the self-attention and cross-attention caches.
        """
        # --- self-attention ---
        self_attn_cache = past_key_value[:2]
        attn_out, self_present = ATTN_FORWARD_MAP[attn_impl](
            self.self_attn,
            hidden_states=hidden_states,
            past_key_value=self_attn_cache,
            attention_mask=attention_mask,
            cache_position=cache_position,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + attn_out)

        # --- cross-attention ---
        cross_attn_cache = past_key_value[-2:]
        attn_out, cross_present = ATTN_FORWARD_MAP[attn_impl](
            self.encoder_attn,
            hidden_states=hidden_states,
            key_value_states=encoder_hidden_states,
            past_key_value=cross_attn_cache,
            attention_mask=encoder_attention_mask,
            cache_position=cache_position,
        )
        hidden_states = self.encoder_attn_layer_norm(hidden_states + attn_out)
        layer_present = self_present + cross_present

        # --- feed-forward ---
        ff_out = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        hidden_states = self.final_layer_norm(hidden_states + ff_out)

        return hidden_states, layer_present
|
208
|
+
|
209
|
+
|
210
|
+
class _BartDecoder(BartDecoder):
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        past_key_values: torch.Tensor,
        cache_position: torch.Tensor,
        attn_impl: str = "eager",
    ):
        """
        Embeds the input tokens, expands the masks to 4-d and runs every decoder layer with its cache.

        Args:
            input_ids: Token ids to decode.
            attention_mask: 2-d decoder mask, expanded to a causal 4-d mask.
            encoder_attention_mask: 2-d encoder mask, expanded to a 4-d mask.
            encoder_hidden_states: Forwarded unchanged to every layer.
            past_key_values: Indexable collection holding one cache entry per decoder layer.
            cache_position: Used for the position-embedding lookup (after adding the embedding offset),
                passed to the causal mask helper and forwarded to every layer.
            attn_impl: Attention implementation name forwarded to every layer.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` carrying the last hidden state and a tuple with the
            cache returned by each layer.
        """
        # Token embedding plus the position embedding selected by `cache_position`.
        positions = self.embed_positions.weight[cache_position + self.embed_positions.offset]
        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        hidden_states = self.layernorm_embedding(inputs_embeds + positions)

        # Pick the mask helpers that match the attention backend of the decoder.
        if self._use_sdpa:
            build_causal_mask = _prepare_4d_causal_attention_mask_for_sdpa
            build_cross_mask = _prepare_4d_attention_mask_for_sdpa
        else:
            build_causal_mask = _prepare_4d_causal_attention_mask
            build_cross_mask = _prepare_4d_attention_mask

        input_shape = input_ids.size()
        attention_mask = build_causal_mask(attention_mask, input_shape, inputs_embeds, cache_position)
        encoder_attention_mask = build_cross_mask(
            encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
        )

        # Run the layers one after another, collecting the cache each of them returns.
        collected_cache = []
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, layer_cache = _BartDecoderLayer.forward(
                layer,
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_values[layer_idx],
                cache_position=cache_position,
                attn_impl=attn_impl,
            )
            collected_cache.append(layer_cache)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=tuple(collected_cache),
        )
|
268
|
+
|
269
|
+
|
270
|
+
class BartDecoderWrapper(torch.nn.Module):
    """Wraps the decoder and LM head of a BART model behind a flat, tensor-only interface."""

    def __init__(self, model: "BartForConditionalGeneration"):
        super().__init__()
        self.config = model.config
        self.decoder = model.get_decoder()
        self.num_layers = self.config.decoder_layers
        self.lm_head = model.lm_head

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        self_kv_cache: torch.Tensor,
        cross_kv_cache: torch.Tensor,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
        """
        Runs one decoding step.

        Args:
            input_ids: Token ids to decode.
            attention_mask: Decoder attention mask.
            encoder_attention_mask: Encoder attention mask used by the cross-attention.
            cache_position: Position of the current step in the self-attention cache.
            self_kv_cache: Self-attention cache indexed along its first dimension as
                key(layer 0), value(layer 0), key(layer 1), ...
            cross_kv_cache: Cross-attention cache, indexed the same way as `self_kv_cache`.

        Returns:
            The LM logits and the updated self-attention cache, stacked in the same
            key/value-per-layer order as the input.
        """
        # Regroup the flat caches into one (self_key, self_value, cross_key, cross_value) tuple per layer.
        kv_cache = tuple(
            (
                self_kv_cache[i],
                self_kv_cache[i + 1],
                cross_kv_cache[i],
                cross_kv_cache[i + 1],
            )
            for i in range(0, self.num_layers * 2, 2)
        )

        # A 1-d placeholder stands in for the encoder states, so the attention layers take the
        # cross-attention key/value from the cache.
        decoder_outputs = _BartDecoder.forward(
            self.decoder,
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            cache_position=cache_position,
            past_key_values=kv_cache,
            encoder_hidden_states=torch.tensor([1]),
            attn_impl=self.config._attn_implementation,
        )
        lm_logits = self.lm_head(decoder_outputs[0])

        # Get the self-attention cache from the outputs: key and value of every layer, in layer order.
        layer_caches = decoder_outputs[1]
        updated_self_kv = [
            layer_caches[layer_idx][kv_idx] for layer_idx in range(self.num_layers) for kv_idx in (0, 1)
        ]
        self_kv_cache = torch.stack(updated_self_kv, dim=0)

        return lm_logits, self_kv_cache
|
323
|
+
|
324
|
+
|
325
|
+
class BartEncoderWrapper(torch.nn.Module):
    """
    Wraps the encoder of a BART model and pre-computes the cross-attention key/value states
    that the decoder reuses at every generation step.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config
        self.decoder = model.get_decoder()
        self.encoder = model.get_encoder()
        # Kept for backward compatibility; `forward` sizes the dummy cache from the decoder instead.
        self.num_layers = self.config.encoder_layers
        self.decoder_max_length = self.config.max_position_embeddings
        self.encoder_max_length = self.config.max_position_embeddings
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> Tuple[torch.Tensor]:
        """
        Encodes the input and returns the cross-attention key/value states of every decoder layer.

        Args:
            input_ids: Encoder input token ids; the first dimension is the batch size.
            attention_mask: Encoder attention mask.

        Returns:
            A single tensor stacking, for every decoder layer, the pair (cross key, cross value).
        """
        encoder_batch_size = input_ids.shape[0]
        decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search

        # 1. run encoder
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_states = encoder_outputs[0]

        # 2. run dummy decoder to get pre-calculated cross-key_values for generation
        # The decoder indexes this list once per decoder layer, so it must be sized by the number of
        # decoder layers (which may differ from the number of encoder layers).
        self_attn_shape = (decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
        cross_attn_shape = (encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
        dummy_past_key_value = [
            (
                torch.zeros(*self_attn_shape),
                torch.zeros(*self_attn_shape),
                torch.zeros(*cross_attn_shape),
                torch.zeros(*cross_attn_shape),
            )
            for _ in range(len(self.decoder.layers))
        ]

        # Only the first decoder position is marked as valid for the dummy step.
        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.int64)
        decoder_attention_mask[:, :1] = 1

        decoder_outputs = _BartDecoder.forward(
            self.decoder,
            input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
            attention_mask=decoder_attention_mask,
            encoder_attention_mask=attention_mask,
            cache_position=torch.tensor(0, dtype=torch.int32),
            encoder_hidden_states=last_hidden_states,
            past_key_values=dummy_past_key_value,
            attn_impl=self.config._attn_implementation,
        )
        first_past_kv = decoder_outputs[1]

        # 3. return cross_key_values to recurrence port. fyi (enc_ir.outputs[0] -> dec_ir.inputs[5])
        # Every layer cache is (self key, self value, cross key, cross value); keep the last two.
        encoder_kv = torch.stack(
            [torch.stack(layer_out[2:], dim=0) for layer_out in first_past_kv],
            dim=0,
        )

        return encoder_kv
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|