optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -0,0 +1,116 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import logging
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+ from transformers import AutoConfig, AutoModel, CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection
+ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+ from ....modeling_base import RBLNModel
+ from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel
+
+
+ class _TextEncoder(torch.nn.Module):
+     def __init__(self, enc: "CLIPTextModel"):
+         super().__init__()
+         enc.config.return_dict = False
+         enc.config.output_hidden_states = True
+         self.enc = enc
+
+     def forward(self, inp):
+         enc_out = self.enc(inp)
+         return enc_out
+
+
+ class RBLNCLIPTextModel(RBLNModel):
+     model_type = "rbln_clip"
+     auto_model_class = AutoModel  # feature extraction
+     original_model_class = CLIPTextModel
+     original_config_class = CLIPTextConfig
+
+     def __post_init__(self, **kwargs):
+         self.dtype = torch.float32
+
+     @classmethod
+     def from_pretrained(cls, *args, **kwargs):
+         configtmp = AutoConfig.from_pretrained
+         modeltmp = AutoModel.from_pretrained
+         AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
+         AutoModel.from_pretrained = cls.original_model_class.from_pretrained
+         rt = super().from_pretrained(*args, **kwargs)
+         AutoConfig.from_pretrained = configtmp
+         AutoModel.from_pretrained = modeltmp
+         return rt
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+         return _TextEncoder(model).eval()
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model_config: "CLIPTextConfig",
+         rbln_batch_size: Optional[int] = None,
+         rbln_img_width: Optional[int] = None,
+         rbln_img_height: Optional[int] = None,
+     ) -> RBLNConfig:
+         model_config.return_dict = False
+         if rbln_batch_size is None:
+             rbln_batch_size = 1
+
+         rbln_runtime_config = RBLNRuntimeConfig(
+             input_info=[
+                 (
+                     "input_ids",
+                     [
+                         rbln_batch_size,
+                         model_config.max_position_embeddings,
+                     ],
+                     "int64",
+                 ),
+             ],
+         )
+
+         rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
+         return rbln_config
+
+     def forward(self, input_ids: "torch.Tensor", **kwargs):
+         text_output = super().forward(input_ids)
+         return CLIPTextModelOutput(
+             text_embeds=text_output[0],
+             last_hidden_state=text_output[1],
+             hidden_states=text_output[2:],
+         )
+
+
+ class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
+     original_model_class = CLIPTextModelWithProjection
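Note: `RBLNCLIPTextModel.from_pretrained` above temporarily reroutes `AutoConfig.from_pretrained` and `AutoModel.from_pretrained` to the CLIP text classes so the generic base-class loader resolves them, then restores the originals. The sketch below is a minimal standalone illustration of that patch-and-restore pattern, not the package's code; `load_with_clip_classes` is a hypothetical helper name, and the `try`/`finally` is an added safeguard so the loaders are restored even if loading fails.

from transformers import AutoConfig, AutoModel, CLIPTextConfig, CLIPTextModel


def load_with_clip_classes(model_id: str, **kwargs) -> CLIPTextModel:
    # Hypothetical helper: reroute the generic Auto* loaders to the CLIP text
    # classes, run the generic loading path, then put the originals back.
    original_config_loader = AutoConfig.from_pretrained
    original_model_loader = AutoModel.from_pretrained
    try:
        AutoConfig.from_pretrained = CLIPTextConfig.from_pretrained
        AutoModel.from_pretrained = CLIPTextModel.from_pretrained
        # While the patch is active, anything calling AutoModel.from_pretrained
        # builds a CLIPTextModel instead of whatever Auto* would normally pick.
        return AutoModel.from_pretrained(model_id, **kwargs)
    finally:
        # Restore the loaders even on failure so later Auto* calls are unaffected.
        AutoConfig.from_pretrained = original_config_loader
        AutoModel.from_pretrained = original_model_loader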
optimum/rbln/transformers/models/gpt2/__init__.py
@@ -0,0 +1,24 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from .modeling_gpt2 import RBLNGPT2LMHeadModel
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -0,0 +1,253 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+     BaseModelOutputWithPastAndCrossAttentions,
+ )
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+
+
+ class _GPT2Attention(GPT2Attention):
+     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+         attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+         if self.scale_attn_weights:
+             attn_weights = attn_weights / torch.full(
+                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+             )
+
+         # Layer-wise attention scaling
+         if self.scale_attn_by_inverse_layer_idx:
+             attn_weights = attn_weights / float(self.layer_idx + 1)
+
+         # -------------------
+         # The block below is removed because the "where" op is not supported on the RBLN graph.
+         # -------------------
+         # if not self.is_cross_attention:
+         #     # if only "normal" attention layer implements causal mask
+         #     query_length, key_length = query.size(-2), key.size(-2)
+         #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+         #     mask_value = torch.finfo(attn_weights.dtype).min
+         #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+         #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+         #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+         #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+         if attention_mask is not None:
+             # Apply the attention mask
+             attn_weights = attn_weights + attention_mask
+
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+         attn_weights = attn_weights.type(value.dtype)
+         # attn_weights = self.attn_dropout(attn_weights)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             attn_weights = attn_weights * head_mask
+
+         attn_output = torch.matmul(attn_weights, value)
+
+         return attn_output, attn_weights
+
+     def forward(
+         self,
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         layer_past: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+         query = self._split_heads(query, self.num_heads, self.head_dim)
+         key = self._split_heads(key, self.num_heads, self.head_dim)
+         value = self._split_heads(value, self.num_heads, self.head_dim)
+
+         if layer_past is not None:
+             past_key, past_value = layer_past
+             query_length = query.shape[-2]
+
+             key = torch.slice_scatter(past_key, key, dim=2, start=cache_position, end=cache_position + query_length)
+             value = torch.slice_scatter(
+                 past_value, value, dim=2, start=cache_position, end=cache_position + query_length
+             )
+
+         present = (key, value)
+         attn_output, _ = _GPT2Attention._attn(self, query, key, value, attention_mask, head_mask)
+
+         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+         attn_output = self.c_proj(attn_output)
+         attn_output = self.resid_dropout(attn_output)
+
+         outputs = (attn_output, present)
+
+         return outputs
+
+
+ class _GPT2Block(GPT2Block):
+     def forward(
+         self,
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         layer_past: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = False,
+         output_attentions: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+         residual = hidden_states
+         hidden_states = self.ln_1(hidden_states)
+
+         attn_outputs = _GPT2Attention.forward(
+             self.attn,
+             hidden_states,
+             layer_past=layer_past,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             cache_position=cache_position,
+         )
+
+         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+         outputs = attn_outputs[1:]
+         # residual connection
+         hidden_states = attn_output + residual
+
+         residual = hidden_states
+         hidden_states = self.ln_2(hidden_states)
+         feed_forward_hidden_states = self.mlp(hidden_states)
+         # residual connection
+         hidden_states = residual + feed_forward_hidden_states
+
+         outputs = (hidden_states,) + outputs
+         return outputs  # hidden_states, present, (attentions, cross_attentions)
+
+
+ class _GPT2Model(GPT2Model):
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+         input_shape = input_ids.size()
+
+         if position_ids is None:
+             # force dtype to torch.long -> torch.int32 (to match cache_position)
+             position_ids = torch.arange(0, input_shape[-1], dtype=torch.int32) + cache_position
+             position_ids = position_ids.unsqueeze(0)
+
+         # GPT2Attention mask.
+         # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
+         attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # head_mask has shape n_layer x batch x n_heads x N x N
+         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+         inputs_embeds = self.wte(input_ids)
+         position_embeds = self.wpe(position_ids)
+         hidden_states = inputs_embeds + position_embeds
+
+         hidden_states = self.drop(hidden_states)
+
+         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+         presents = ()
+         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+             outputs = _GPT2Block.forward(
+                 block,
+                 hidden_states,
+                 layer_past=layer_past,
+                 attention_mask=attention_mask,
+                 head_mask=head_mask[i],
+                 cache_position=cache_position,
+             )
+             hidden_states = outputs[0]
+
+             presents = presents + (outputs[1],)
+
+         hidden_states = self.ln_f(hidden_states)
+         hidden_states = hidden_states.view(output_shape)
+         return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=presents)
+
+
+ class GPT2LMHeadModelWrapper(torch.nn.Module):
+     def __init__(self, gpt):
+         super().__init__()
+         self.model = gpt
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         past_key_values: torch.Tensor,
+         attention_mask: torch.Tensor,
+         cache_position: torch.LongTensor,
+     ):
+         kv_cache = []
+         for i in range(self.model.config.n_layer):
+             kv_cache.append((past_key_values[i, 0], past_key_values[i, 1]))
+
+         transformer_outputs = _GPT2Model.forward(
+             self.model.transformer,
+             input_ids=input_ids,
+             past_key_values=kv_cache,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+         )
+
+         hidden_states = transformer_outputs[0]
+
+         # TODO : Use query_length here to pick last logit
+         # batch_size, sequence_length = hidden_states.shape[:2]
+         # hidden_states = hidden_states.view(batch_size * sequence_length, -1)
+         # hidden_states = torch.nn.functional.embedding(query_length, hidden_states)
+         # hidden_states = hidden_states.view(batch_size, 1, -1)
+
+         lm_logits = self.model.lm_head(hidden_states)
+         kv_cache = transformer_outputs[1]
+
+         past_key_values = []
+         for i in range(self.model.config.n_layer):
+             past_key_values.append(torch.stack(kv_cache[i], dim=0))
+         past_key_values = torch.stack(past_key_values, dim=0)
+
+         return lm_logits, past_key_values
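Note: two details in `_GPT2Attention.forward` and `_GPT2Model.forward` above are easy to miss: new key/value states are written into a preallocated cache with `torch.slice_scatter` at `cache_position`, and causal masking is applied by adding a large negative bias to the attention scores, replacing the `torch.where`-based masking that the RBLN graph cannot express. The snippet below is a minimal, self-contained sketch of both ideas; the tensor names and toy shapes are illustrative and not taken from the package.

import torch

# Toy shapes mirroring the (batch, heads, seq, head_dim) cache layout used above.
batch, heads, max_seq, head_dim = 1, 2, 8, 4
query_length, cache_position = 3, 2

past_key = torch.zeros(batch, heads, max_seq, head_dim)     # preallocated static KV cache
new_key = torch.ones(batch, heads, query_length, head_dim)  # keys for the current step

# Static cache update: overwrite rows [cache_position, cache_position + query_length)
# along the sequence dim without any data-dependent control flow.
updated_key = torch.slice_scatter(
    past_key, new_key, dim=2, start=cache_position, end=cache_position + query_length
)

# Additive masking: positions marked 0 in the mask receive a very large negative bias,
# so softmax assigns them ~0 probability -- the same effect as the removed torch.where.
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]])
additive_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min
scores = torch.zeros(batch, heads, query_length, max_seq) + additive_mask
probs = torch.softmax(scores, dim=-1)  # masked positions get ~0 attention weight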