optimum_rbln-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -0,0 +1,439 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.t5.configuration_t5 import T5Config
+from transformers.models.t5.modeling_t5 import (
+    T5Attention,
+    T5Block,
+    T5LayerCrossAttention,
+    T5LayerSelfAttention,
+    T5Stack,
+)
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import T5ForConditionalGeneration
+
+
+class T5Encoder(T5Stack):
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        position_bias: torch.Tensor,
+    ) -> BaseModelOutput:
+        hidden_states = self.embed_tokens(input_ids)
+        extended_attention_mask = self.invert_attention_mask(attention_mask)
+        position_bias = position_bias + extended_attention_mask
+        for i, layer_module in enumerate(self.block):
+            layer_outputs = _T5Block.forward(
+                layer_module,
+                hidden_states,
+                position_bias=position_bias,
+            )
+            hidden_states = layer_outputs[0]
+        hidden_states = self.final_layer_norm(hidden_states)
+        return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class T5Decoder(T5Stack):
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        past_key_values: torch.Tensor,
+        position_bias: torch.Tensor,
+        encoder_decoder_position_bias: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> BaseModelOutputWithPastAndCrossAttentions:
+        hidden_states = self.embed_tokens(input_ids)
+        extended_attention_mask = self.invert_attention_mask(attention_mask)
+        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+
+        position_bias = position_bias + extended_attention_mask
+        encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask
+
+        present_key_value_states = ()
+        for layer_module, past_key_value in zip(self.block, past_key_values):
+            layer_outputs = _T5Block.forward(
+                layer_module,
+                hidden_states,
+                position_bias=position_bias,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_decoder_position_bias=encoder_decoder_position_bias,
+                past_key_value=past_key_value,
+                cache_position=cache_position,
+            )
+            hidden_states, present_key_value_state = layer_outputs[:2]
+            present_key_value_states = present_key_value_states + (present_key_value_state,)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=present_key_value_states,
+        )
+
+
+class T5EncoderWrapper(torch.nn.Module):
+    def __init__(self, model: "T5ForConditionalGeneration"):
+        super().__init__()
+        self.config = model.config
+        self.model = model
+        self.encoder = model.encoder
+        self.decoder = model.decoder
+        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
+            self.config, "max_position_embeddings", None
+        )
+        self.encoder_max_length = None
+        self.decoder_max_length = None
+        self.decoder_batch_size = 1
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        encoder_batch_size = input_ids.shape[0]
+        decoder_batch_size = self.decoder_batch_size
+        decoder_max_length = self.decoder_max_length or self.default_max_length
+        encoder_max_length = self.encoder_max_length or self.default_max_length
+
+        attn_layer = self.encoder.block[0].layer[0].SelfAttention
+        encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
+        encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias)
+
+        attn_layer = self.decoder.block[0].layer[0].SelfAttention
+        decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
+        decoder_position_bias = decoder_position_bias[:, :, :1]
+
+        attn_layer = self.decoder.block[0].layer[1].EncDecAttention
+        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
+
+        dummy_past_key_value = []
+        for i in range(self.config.num_layers):
+            pkv_self_attn_key = torch.zeros(
+                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
+            )
+            pkv_self_attn_value = torch.zeros(
+                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
+            )
+            pkv_cross_attn_key = torch.zeros(
+                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
+            )
+            pkv_cross_attn_value = torch.zeros(
+                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
+            )
+            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
+            dummy_past_key_value.append(layer_pkv)
+
+        decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.int64)
+        decoder_attention_mask[:, :1] = 1
+
+        # Since the first decoder step has a different graph from the following steps,
+        # the first decoder step is merged into the corresponding encoder here.
+        # TODO(jongho): Separate first-step-decoder.
+        decoder_outputs = T5Decoder.forward(
+            self.decoder,
+            input_ids=torch.zeros(decoder_batch_size, 1, dtype=torch.int64),
+            attention_mask=decoder_attention_mask,
+            position_bias=decoder_position_bias,
+            encoder_decoder_position_bias=encoder_decoder_position_bias,
+            encoder_hidden_states=encoder_outputs.last_hidden_state,
+            encoder_attention_mask=attention_mask,
+            past_key_values=dummy_past_key_value,
+            cache_position=torch.tensor(0, dtype=torch.int32),
+        )
+
+        past_key_values = decoder_outputs.past_key_values
+
+        cross_kv_cache = []
+        for i in range(self.model.config.num_layers):
+            cross_kv_cache.append(past_key_values[i][2])
+            cross_kv_cache.append(past_key_values[i][3])
+        cross_kv_cache = torch.stack(cross_kv_cache, dim=0)
+
+        return cross_kv_cache
+
+
+class T5DecoderWrapper(torch.nn.Module):
+    def __init__(self, model: "T5ForConditionalGeneration"):
+        super().__init__()
+        self.config = model.config
+        self.model = model
+        self.encoder = model.encoder
+        self.decoder = model.decoder
+        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
+            self.config, "max_position_embeddings", None
+        )
+        self.encoder_max_length = None
+        self.decoder_max_length = None
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        self_kv_cache: torch.Tensor,
+        cross_kv_cache: torch.Tensor,
+    ) -> Tuple[torch.Tensor]:
+        # cache_position: starts from step 0.
+        # attention_mask: starts with a single attended position ([0:cache_position + 1]).
+        num_layers = self.model.config.num_layers
+        encoder_max_length = self.encoder_max_length or self.default_max_length
+        decoder_max_length = self.decoder_max_length or self.default_max_length
+
+        kv_cache = ()
+        for i in range(0, num_layers * 2, 2):
+            kv_cache = kv_cache + (
+                (
+                    self_kv_cache[i],
+                    self_kv_cache[i + 1],
+                    cross_kv_cache[i],
+                    cross_kv_cache[i + 1],
+                ),
+            )
+
+        attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
+        _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
+        decoder_position_bias = _decoder_position_bias[:, :, cache_position].unsqueeze(2)
+
+        attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
+        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
+
+        decoder_outputs = T5Decoder.forward(
+            self.model.decoder,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=1,
+            encoder_attention_mask=encoder_attention_mask,
+            position_bias=decoder_position_bias,
+            encoder_decoder_position_bias=encoder_decoder_position_bias,
+            past_key_values=kv_cache,
+            cache_position=cache_position,
+        )
+
+        past_key_values = decoder_outputs.past_key_values
+        sequence_output = decoder_outputs[0]
+        if self.model.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            sequence_output = sequence_output * (self.model.model_dim**-0.5)
+        lm_logits = self.model.lm_head(sequence_output)
+
+        self_kv_cache = []
+        for i in range(self.model.config.num_layers):
+            self_kv_cache.append(past_key_values[i][0])
+            self_kv_cache.append(past_key_values[i][1])
+
+        self_kv_cache = torch.stack(self_kv_cache, dim=0)
+
+        return lm_logits, self_kv_cache
+
+
+class _T5Attention(T5Attention):
+    def __init__(self, config: T5Config, has_relative_attention_bias=False):
+        super().__init__(config, has_relative_attention_bias)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Tuple[torch.Tensor] = None,
+        position_bias: torch.Tensor = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        cache_position: Optional[torch.Tensor] = None,  # current length of the cached sequence
+        is_self_attn: Optional[bool] = None,
+    ) -> Tuple[torch.Tensor]:
+        batch_size = hidden_states.shape[0]
+        cross_batch_size = key_value_states.shape[0] if not is_self_attn and cache_position == 0 else None
+
+        def shape(states, batch_size):
+            """projection"""
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+        def unshape(states, batch_size):
+            """reshape"""
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+        query_states = shape(self.q(hidden_states), batch_size)  # (batch_size, n_heads, seq_length, dim_per_head)
+
+        # projection
+        if is_self_attn:
+            key_states = shape(self.k(hidden_states), batch_size)
+            value_states = shape(self.v(hidden_states), batch_size)
+            if past_key_value is not None:
+                # decoder self attn
+                cache_k = past_key_value[0].slice_scatter(
+                    key_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                cache_v = past_key_value[1].slice_scatter(
+                    value_states, dim=2, start=cache_position, end=cache_position + 1
+                )
+                past_key_value = (cache_k, cache_v)
+                key_states, value_states = past_key_value
+
+        else:
+            # cross-attn
+            if cache_position == 0:
+                key_states = shape(self.k(key_value_states), cross_batch_size)
+                value_states = shape(self.v(key_value_states), cross_batch_size)
+                past_key_value = key_states, value_states
+            else:
+                key_states = past_key_value[0]
+                value_states = past_key_value[1]
+
+        # compute scores
+        scores = torch.matmul(query_states, key_states.transpose(3, 2))
+        scores += position_bias
+
+        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+            scores
+        )  # (batch_size, n_heads, seq_length, key_length)
+
+        attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size)  # (batch_size, seq_length, dim)
+        attn_output = self.o(attn_output)
+
+        outputs = (attn_output,) + (past_key_value,)
+        return outputs
+
+
+class _T5LayerSelfAttention(T5LayerSelfAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_bias: torch.Tensor = None,
+        past_key_value: Tuple[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ):
+        normed_hidden_states = self.layer_norm(hidden_states)
+        attention_output = _T5Attention.forward(
+            self.SelfAttention,
+            hidden_states=normed_hidden_states,
+            position_bias=position_bias,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
+            is_self_attn=True,
+        )
+
+        # Residual Connection
+        hidden_states = hidden_states + self.dropout(attention_output[0])
+        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
+        return outputs
+
+
+class _T5LayerCrossAttention(T5LayerCrossAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: torch.Tensor,
+        position_bias: torch.Tensor = None,
+        past_key_value: Tuple[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+    ):
+        normed_hidden_states = self.layer_norm(hidden_states)
+        attention_output = _T5Attention.forward(
+            self.EncDecAttention,
+            hidden_states=normed_hidden_states,
+            key_value_states=key_value_states,
+            position_bias=position_bias,
+            past_key_value=past_key_value,
+            cache_position=cache_position,
+            is_self_attn=False,
+        )
+
+        # Residual connection
+        layer_output = hidden_states + self.dropout(attention_output[0])
+        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
+        return outputs
+
+
+class _T5Block(T5Block):
+    def forward(
+        self,
+        hidden_states,
+        position_bias=None,
+        encoder_hidden_states=None,
+        encoder_decoder_position_bias=None,
+        past_key_value=None,
+        cache_position=None,
+    ):
+        if past_key_value is not None:
+            if not self.is_decoder:
+                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
+            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+
+            if len(past_key_value) != expected_num_past_key_values:
+                raise ValueError(
+                    f"There should be {expected_num_past_key_values} past states. "
+                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+                    f"Got {len(past_key_value)} past key / value states"
+                )
+
+            self_attn_past_key_value = past_key_value[:2]
+            if self_attn_past_key_value == (None, None):
+                self_attn_past_key_value = None
+
+            cross_attn_past_key_value = past_key_value[2:]
+        else:
+            self_attn_past_key_value, cross_attn_past_key_value = None, None
+
+        self_attention_outputs = _T5LayerSelfAttention.forward(
+            self.layer[0],
+            hidden_states=hidden_states,
+            position_bias=position_bias,
+            past_key_value=self_attn_past_key_value,
+            cache_position=cache_position,
+        )
+
+        hidden_states, present_key_value_state = self_attention_outputs[:2]
+
+        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
+        if do_cross_attention:
+            cross_attention_outputs = _T5LayerCrossAttention.forward(
+                self.layer[1],
+                hidden_states,
+                key_value_states=encoder_hidden_states,
+                position_bias=encoder_decoder_position_bias,
+                past_key_value=cross_attn_past_key_value,
+                cache_position=cache_position,
+            )
+            hidden_states = cross_attention_outputs[0]
+            # Combine self attn and cross attn key value states
+            if present_key_value_state is not None:
+                present_key_value_state = present_key_value_state + cross_attention_outputs[1]
+
+        # Apply Feed Forward layer
+        hidden_states = self.layer[-1](hidden_states)
+
+        outputs = (hidden_states,)
+        outputs = outputs + (present_key_value_state,)
+
+        return outputs
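
The two wrapper modules above recast T5 generation as fixed-shape graphs: T5EncoderWrapper runs the encoder and folds the first decoder step into the same graph, returning only the stacked cross-attention key/value cache, while T5DecoderWrapper replays every later step against externally held self- and cross-attention caches indexed by cache_position. Below is a minimal illustrative sketch, not part of the package, of driving the encoder-side wrapper eagerly with dummy fixed-length inputs; the checkpoint name, the 512-token static lengths, and the assumption that eager execution works with the pinned transformers version are all hypothetical.

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")  # illustrative checkpoint

enc_wrapper = T5EncoderWrapper(model)
enc_wrapper.encoder_max_length = 512  # assumed static encoder length
enc_wrapper.decoder_max_length = 512  # assumed static decoder length

# Inputs must already be padded to the static encoder length.
input_ids = torch.zeros(1, 512, dtype=torch.int64)
attention_mask = torch.ones(1, 512, dtype=torch.int64)

# Returns the stacked cross-attention KV cache with shape
# (num_layers * 2, batch, num_heads, encoder_max_length, d_kv);
# T5DecoderWrapper consumes it at every subsequent decoding step.
cross_kv_cache = enc_wrapper(input_ids, attention_mask)
print(cross_kv_cache.shape)

The decoder-side wrapper then takes one new token, the running attention_mask, cache_position, and the two stacked caches, and returns the step logits together with the updated self-attention cache.
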
optimum/rbln/transformers/models/wav2vec2/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -0,0 +1,121 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import logging
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
+from transformers.modeling_outputs import CausalLMOutput
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PretrainedConfig,
+    )
+
+
+class _Wav2Vec2(torch.nn.Module):
+    def __init__(self, model: "Wav2Vec2ForCTC"):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_values):
+        output = self.model.wav2vec2(input_values=input_values)
+        return self.model.lm_head(output[0])
+
+
+class RBLNWav2Vec2ForCTC(RBLNModel):
+    """
+    Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models.
+
+    It implements the methods to convert a pre-trained Wav2Vec2 model into an RBLN Wav2Vec2 model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    """
+
+    model_type = "rbln_model"
+    main_input_name = "input_values"
+    auto_model_class = AutoModelForMaskedLM
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+        return _Wav2Vec2(model).eval()
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        meta = {}
+
+        if rbln_max_seq_len is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    rbln_max_seq_len = tokenizer.model_max_length
+                    break
+            if rbln_max_seq_len is None:
+                raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        meta["rbln_max_seq_len"] = rbln_max_seq_len
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            (
+                "input_values",
+                [
+                    rbln_batch_size,
+                    rbln_max_seq_len,
+                ],
+                "float32",
+            ),
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info, batch_size=rbln_batch_size)
+
+        rbln_config = RBLNConfig.from_rbln_runtime_configs(
+            [rbln_runtime_config],
+            _rbln_meta=meta,
+        )
+
+        return rbln_config
+
+    def forward(self, input_values: "torch.Tensor", **kwargs):
+        outputs = super().forward(input_values, **kwargs)
+        return CausalLMOutput(logits=outputs)
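
RBLNWav2Vec2ForCTC declares a single static `input_values` input of shape `[rbln_batch_size, rbln_max_seq_len]` in `_get_rbln_config`, so the audio length is fixed at compile time and `forward` simply wraps the compiled graph's output in a `CausalLMOutput`. The following is a hedged usage sketch, not taken from the package docs, assuming the optimum-style `from_pretrained(..., export=True)` entry point; the checkpoint name and the 160000-sample (~10 s at 16 kHz) static length are illustrative, and `rbln_max_seq_len` is taken here to count raw audio samples.

import torch
from optimum.rbln.transformers.models.wav2vec2 import RBLNWav2Vec2ForCTC

# Assumed export-style entry point; the rbln_* kwargs mirror _get_rbln_config above.
model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",  # illustrative checkpoint
    export=True,                    # compile the checkpoint to an RBLN graph (assumed flag)
    rbln_max_seq_len=160000,        # static number of audio samples
    rbln_batch_size=1,
)

# Inference must match the compiled static shape (batch, samples).
dummy_audio = torch.zeros(1, 160000, dtype=torch.float32)
logits = model(dummy_audio).logits  # CausalLMOutput.logits
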
optimum/rbln/transformers/models/whisper/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_whisper import RBLNWhisperForConditionalGeneration