optimum_rbln-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/whisper/whisper_architecture.py
@@ -0,0 +1,406 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers.modeling_attn_mask_utils import (
+     _prepare_4d_causal_attention_mask,
+     _prepare_4d_causal_attention_mask_for_sdpa,
+ )
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPastAndCrossAttentions,
+     Seq2SeqLMOutput,
+ )
+ from transformers.models.whisper.modeling_whisper import (
+     WhisperAttention,
+     WhisperDecoder,
+     WhisperDecoderLayer,
+     WhisperPositionalEmbedding,
+     WhisperSdpaAttention,
+ )
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class _WhisperAttention(WhisperAttention):
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_value_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, tgt_len, _ = hidden_states.size()
+         is_cross_attention = key_value_states is not None
+
+         query_states = self.q_proj(hidden_states) * self.scaling
+
+         if is_cross_attention:
+             # 3-D encoder states: project fresh cross-attention KV; a 1-D dummy
+             # tensor signals that the cached cross-attention KV should be reused
+             is_dummy_decoder = len(key_value_states.shape) > 1
+             if is_dummy_decoder:
+                 key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+                 value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+             else:
+                 key_states = past_key_value[0]
+                 value_states = past_key_value[1]
+         else:
+             if self.is_decoder:
+                 key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+                 value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+                 key_states = past_key_value[0].slice_scatter(
+                     key_states, dim=2, start=cache_position, end=cache_position + 1
+                 )
+                 value_states = past_key_value[1].slice_scatter(
+                     value_states, dim=2, start=cache_position, end=cache_position + 1
+                 )
+             else:
+                 key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+                 value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+         if self.is_decoder:
+             present_key_value = (key_states, value_states)
+         else:
+             present_key_value = None
+
+         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+         key_states = key_states.reshape(*proj_shape)
+         value_states = value_states.reshape(*proj_shape)
+
+         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+         src_len = key_states.size(1)
+         if attention_mask is not None:
+             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+         attn_output = torch.bmm(attn_weights, value_states)
+
+         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+         attn_output = attn_output.transpose(1, 2)
+
+         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output, None, present_key_value
+
+
+ class _WhisperSdpaAttention(WhisperSdpaAttention):
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_value_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, tgt_len, _ = hidden_states.size()
+
+         is_cross_attention = key_value_states is not None
+
+         query_states = self.q_proj(hidden_states)
+
+         if is_cross_attention:
+             is_dummy_decoder = len(key_value_states.shape) > 1
+             if is_dummy_decoder:
+                 key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+                 value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+             else:
+                 key_states = past_key_value[0]
+                 value_states = past_key_value[1]
+         else:
+             if self.is_decoder:
+                 key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+                 value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+                 key_states = past_key_value[0].slice_scatter(
+                     key_states, dim=2, start=cache_position, end=cache_position + 1
+                 )
+                 value_states = past_key_value[1].slice_scatter(
+                     value_states, dim=2, start=cache_position, end=cache_position + 1
+                 )
+             else:
+                 key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+                 value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+         if self.is_decoder:
+             present_key_value = (key_states, value_states)
+         else:
+             present_key_value = None
+
+         query_states = self._shape(query_states, tgt_len, bsz)
+
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             attn_mask=attention_mask,
+             dropout_p=0.0,
+             is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
+         )
+
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output, None, present_key_value
+
+
+ ATTN_FORWARD_MAP = {"eager": _WhisperAttention.forward, "sdpa": _WhisperSdpaAttention.forward}
+
+
+ class _WhisperDecoderLayer(WhisperDecoderLayer):
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         attn_impl: str = "eager",
+     ) -> torch.Tensor:
+         # Self-attention block
+         residual = hidden_states
+         hidden_states = self.self_attn_layer_norm(hidden_states)
+         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+
+         hidden_states, _, present_key_value = ATTN_FORWARD_MAP[attn_impl](
+             self.self_attn,
+             hidden_states=hidden_states,
+             past_key_value=self_attn_past_key_value,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+         )
+         hidden_states = residual + hidden_states
+
+         # Cross-attention block
+         residual = hidden_states
+         hidden_states = self.encoder_attn_layer_norm(hidden_states)
+         cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None
+
+         hidden_states, _, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
+             self.encoder_attn,
+             hidden_states=hidden_states,
+             key_value_states=encoder_hidden_states,
+             past_key_value=cross_attn_past_key_value,
+             cache_position=cache_position,
+         )
+         hidden_states = residual + hidden_states
+         present_key_value = present_key_value + cross_attn_present_key_value
+
+         # Fully connected block
+         residual = hidden_states
+         hidden_states = self.final_layer_norm(hidden_states)
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states, present_key_value
+
+
+ class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
+     def forward(self, input_ids, past_key_values_length=0, position_ids=None):
+         if position_ids is None:
+             return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
+         else:
+             return self.weight[position_ids]
+
+
+ class _WhisperDecoder(WhisperDecoder):
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         past_key_values: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         attn_impl: str = "eager",
+         **kwargs,
+     ):
+         input_shape = input_ids.size()
+         input_ids = input_ids.view(-1, input_shape[-1])
+
+         # positional embedding (cache_position doubles as the past length and the position ids)
+         inputs_embeds = self.embed_tokens(input_ids)
+         positions = _WhisperPositionalEmbedding.forward(
+             self.embed_positions, input_ids, cache_position, cache_position
+         )
+         hidden_states = inputs_embeds + positions
+
+         # prepare the 4D causal attention mask
+         if self._use_sdpa:
+             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                 attention_mask, input_shape, inputs_embeds, cache_position
+             )
+         else:
+             attention_mask = _prepare_4d_causal_attention_mask(
+                 attention_mask, input_shape, inputs_embeds, cache_position
+             )
+
+         next_decoder_cache = ()
+         # iterate over the decoder layers
+         for idx, decoder_layer in enumerate(self.layers):
+             past_key_value = past_key_values[idx] if past_key_values is not None else None
+             layer_outputs = _WhisperDecoderLayer.forward(
+                 decoder_layer,
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 past_key_value=past_key_value,
+                 cache_position=cache_position,
+                 attn_impl=attn_impl,
+             )
+             hidden_states = layer_outputs[0]
+
+             next_decoder_cache += (layer_outputs[1],)
+
+         # final layer norm
+         hidden_states = self.layer_norm(hidden_states)
+
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=next_decoder_cache,
+         )
+
+
+ class _WhisperDecoderWrapper(torch.nn.Module):
+     def __init__(self, model):
+         super().__init__()
+         self.proj_out = model.proj_out
+         self.config = model.config
+         self.decoder = model.get_decoder()
+         self.num_layers = self.config.decoder_layers
+         self.attn_impl = self.config._attn_implementation
+
+     def forward(
+         self,
+         decoder_input_ids: torch.Tensor,
+         decoder_attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         self_kv_cache: torch.Tensor,
+         cross_kv_cache: torch.Tensor,
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+         # prepare past_key_values
+         kv_cache = ()
+         for i in range(0, self.num_layers * 2, 2):
+             kv_cache = kv_cache + (
+                 (
+                     self_kv_cache[i],
+                     self_kv_cache[i + 1],
+                     cross_kv_cache[i],
+                     cross_kv_cache[i + 1],
+                 ),
+             )
+
+         # Decode (the 1-D dummy encoder_hidden_states tells the attention layers
+         # to reuse the cached cross-attention KV instead of recomputing it)
+         decoder_outputs = _WhisperDecoder.forward(
+             self.decoder,
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             cache_position=cache_position,
+             past_key_values=kv_cache,
+             encoder_hidden_states=torch.tensor([1]),
+             attn_impl=self.attn_impl,
+         )
+         sequence_output = decoder_outputs[0]
+         lm_logits = self.proj_out(sequence_output)
+
+         # gather self_kv_cache from the outputs
+         past_key_values = decoder_outputs[1]
+         self_kv_cache = []
+         for i in range(self.config.decoder_layers):
+             self_kv_cache.append(past_key_values[i][0])
+             self_kv_cache.append(past_key_values[i][1])
+         self_kv_cache = torch.stack(self_kv_cache, dim=0)
+
+         return lm_logits, self_kv_cache
+
+
+ class _WhisperEncoderWrapper(torch.nn.Module):
+     def __init__(self, model):
+         super().__init__()
+         self.model = model
+         self.config = model.config
+         self.decoder = model.get_decoder()
+         self.encoder = model.get_encoder()
+         self.num_layers = self.config.decoder_layers
+         self.decoder_max_length = self.config.max_target_positions
+         self.encoder_max_length = self.config.max_source_positions
+         self.num_heads = self.config.decoder_attention_heads
+         self.d_kv = self.config.d_model // self.num_heads
+         self.attn_impl = self.config._attn_implementation
+
+     def forward(
+         self,
+         input_features: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+         encoder_outputs = self.encoder(input_features=input_features)
+         last_hidden_states = encoder_outputs[0]
+
+         encoder_batch_size = input_features.shape[0]
+         decoder_batch_size = encoder_batch_size  # TODO: fix in the future
+
+         # zero-filled caches for a single dummy decoder pass
+         dummy_past_key_value = []
+         for _ in range(self.num_layers):
+             pkv_self_attn_key = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
+             pkv_self_attn_value = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
+             pkv_cross_attn_key = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
+             pkv_cross_attn_value = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
+             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
+             dummy_past_key_value.append(layer_pkv)
+
+         decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.int64)
+         decoder_attention_mask[:, :1] = 1
+
+         decoder_outputs = _WhisperDecoder.forward(
+             self.decoder,
+             input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
+             attention_mask=decoder_attention_mask,
+             cache_position=torch.tensor(0, dtype=torch.int32),
+             encoder_hidden_states=last_hidden_states,
+             past_key_values=dummy_past_key_value,
+             attn_impl=self.attn_impl,
+         )
+
+         first_past_kv = decoder_outputs[1]
+
+         # keep only the cross-attention KV of every layer
+         encoder_kv = []
+         for layer_out in first_past_kv:
+             encoder_kv.append(torch.stack(layer_out[2:], dim=0))
+         encoder_kv = torch.stack(encoder_kv, dim=0)
+
+         return encoder_kv
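
The `_WhisperEncoderWrapper` above runs the decoder once against zero-filled caches so that the cross-attention key/value projections are captured in a single encoder-side graph. A minimal sketch of how the wrapper might be exercised before compilation; the `openai/whisper-tiny` checkpoint and the dummy-feature shape are illustrative assumptions, not something this diff pins down:

import torch
from transformers import WhisperForConditionalGeneration
from optimum.rbln.transformers.models.whisper.whisper_architecture import _WhisperEncoderWrapper

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()
wrapper = _WhisperEncoderWrapper(model)

# Whisper consumes log-mel features of shape (batch, num_mel_bins, 3000)
dummy_features = torch.zeros(1, model.config.num_mel_bins, 3000)

with torch.no_grad():
    # stacked cross-attention KV cache, one entry per decoder layer:
    # (decoder_layers, 2, batch, num_heads, max_source_positions, d_kv)
    encoder_kv = wrapper(dummy_features)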
optimum/rbln/utils/__init__.py
@@ -0,0 +1,25 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from .import_utils import is_rbln_available
+ from .runtime_utils import RBLNPytorchRuntime
optimum/rbln/utils/import_utils.py
@@ -0,0 +1,28 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import importlib.util
+
+
+ def is_rbln_available() -> bool:
+     # the compiler ships as the "rebel-compiler" distribution, but its import name is "rebel"
+     return importlib.util.find_spec("rebel") is not None
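
Because `is_rbln_available` only probes the import machinery, it can guard runtime construction without importing the compiler eagerly. A small usage sketch (the error message is illustrative):

from optimum.rbln.utils import is_rbln_available

if not is_rbln_available():
    raise ImportError("rebel-compiler is not installed; RBLN runtimes cannot be created.")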
optimum/rbln/utils/runtime_utils.py
@@ -0,0 +1,71 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from typing import Any, Dict, List
+
+ import rebel
+ import torch
+
+
+ class RBLNPytorchRuntime:
+     mandatory_members = []
+
+     def __init__(self, runtime: rebel.Runtime, **kwargs) -> None:
+         self.runtime = runtime
+         for key, value in kwargs.items():
+             setattr(self, key, value)
+         for mandatory_member in __class__.mandatory_members:
+             if mandatory_member not in kwargs:
+                 raise AttributeError(f"`{mandatory_member}` should be assigned to {__class__.__name__} objects.")
+
+     def __call__(self, *args: Any, **kwds: Any) -> Any:
+         return self.forward(*args, **kwds)
+
+     def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
+         # drop any non-tensor inputs before dispatching to the compiled runtime
+         args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
+         kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor), kwargs.items()))
+         output = self.runtime(*args, **kwargs)
+         return output
+
+     def __repr__(self) -> str:
+         return repr(self.runtime)
+
+
+ class UnavailableRuntime:
+     def __call__(self, *args: Any, **kwargs: Any) -> Any:
+         return self.forward(*args, **kwargs)
+
+     def __len__(self) -> int:
+         return 0
+
+     def __getitem__(self, idx: int) -> Any:
+         return self
+
+     def __iter__(self):
+         return iter([self])
+
+     def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
+         raise RuntimeError("The RBLN runtime was not created, so it is unavailable.")
+
+     def __repr__(self) -> str:
+         return "UnavailableRuntime"
optimum/rbln/utils/save_utils.py
@@ -0,0 +1,92 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import logging
+ from pathlib import Path
+ from typing import List, Union
+
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def maybe_load_preprocessors(
+     src_name_or_path: Union[str, Path], subfolder: str = "", trust_remote_code: bool = False
+ ) -> List:
+     preprocessors = []
+     try:
+         preprocessors.append(
+             AutoTokenizer.from_pretrained(src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code)
+         )
+     except Exception:
+         pass
+
+     try:
+         preprocessors.append(
+             AutoProcessor.from_pretrained(src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code)
+         )
+     except Exception:
+         pass
+
+     try:
+         preprocessors.append(
+             AutoFeatureExtractor.from_pretrained(
+                 src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
+             )
+         )
+     except Exception:
+         pass
+     return preprocessors
+
+
+ def maybe_save_preprocessors(
+     src_name_or_path: Union[str, Path],
+     dest_dir: Union[str, Path],
+     src_subfolder: str = "",
+     trust_remote_code: bool = False,
+ ):
+     """
+     Saves the tokenizer, processor, and feature extractor in `dest_dir` when they are found in `src_name_or_path`.
+
+     Args:
+         src_name_or_path (`Union[str, Path]`):
+             The source model directory or Hugging Face Hub repo to load the preprocessors from.
+         dest_dir (`Union[str, Path]`):
+             The destination directory to save the preprocessors to.
+         src_subfolder (`str`, defaults to `""`):
+             In case the preprocessor files are located inside a subfolder of the model directory / repo on the
+             Hugging Face Hub, you can specify the subfolder name here.
+         trust_remote_code (`bool`, defaults to `False`):
+             Whether to allow saving preprocessors that may execute arbitrary code. Use this option at your own
+             risk.
+     """
+     if not isinstance(dest_dir, Path):
+         dest_dir = Path(dest_dir)
+
+     dest_dir.mkdir(exist_ok=True)
+     preprocessors = maybe_load_preprocessors(
+         src_name_or_path, subfolder=src_subfolder, trust_remote_code=trust_remote_code
+     )
+     for preprocessor in preprocessors:
+         preprocessor.save_pretrained(dest_dir)
+     return preprocessors
+ return preprocessors