optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. optimum/rbln/__init__.py +17 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
  5. optimum/rbln/diffusers/models/controlnet.py +7 -3
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/modeling_alias.py +19 -1
  13. optimum/rbln/modeling_base.py +162 -18
  14. optimum/rbln/transformers/__init__.py +8 -0
  15. optimum/rbln/transformers/cache_utils.py +111 -0
  16. optimum/rbln/transformers/generation/utils.py +0 -2
  17. optimum/rbln/transformers/models/__init__.py +3 -0
  18. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  19. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  20. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  21. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
  22. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
  23. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  24. optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
  25. optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
  26. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  27. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
  28. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  29. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  32. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  33. optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
  34. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  35. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  36. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  37. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  38. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  39. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  40. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
  41. optimum/rbln/transformers/utils/__init__.py +0 -0
  42. optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
  43. optimum/rbln/utils/import_utils.py +1 -4
  44. optimum/rbln/utils/runtime_utils.py +2 -1
  45. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
  46. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
  47. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  48. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
  49. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/midm/modeling_midm.py
@@ -23,17 +23,10 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable
 
-import rebel
-import torch
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
-
-from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...generation.utils import RBLNGenerationMixin
+from ....modeling_config import RBLNConfig
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (
     MidmLMHeadModelWrapper,
@@ -41,41 +34,18 @@ from .midm_architecture import (
 
 
 logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
+        PreTrainedModel,
     )
 
 
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    # RBLN_Runtimemodule
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.LongTensor = None,
-        cache_position: torch.Tensor = None,
-        **kwargs: Dict[str, Any],
-    ):
-        logits = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-        return logits
-
-
-class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
+class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
     The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers Midm model into a RBLN transformer model by:
@@ -84,46 +54,10 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
 
     """
 
-    model_type = "rbln_model"
-    auto_model_class = AutoModelForCausalLM
-    main_input_name = "input_ids"
-
-    def __init__(
-        self,
-        models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
-        config: PretrainedConfig = None,
-        preprocessors: Optional[List] = None,
-        rbln_config: Optional[RBLNConfig] = None,
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            models,
-            config,
-            preprocessors,
-            rbln_config,
-            rbln_device=rbln_device,
-            rbln_device_map=rbln_device_map,
-            **kwargs,
-        )
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-        self.prefill_attention_mask = torch.zeros(
-            self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
-        self.past_cached_length = 0
-
-    def can_generate(self):
-        return True
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
+        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         """This is the key method to implement RBLN-Midm.
@@ -139,252 +73,3 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # TODO(jongho): implement
-        raise NotImplementedError
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
-        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
-
-        prefill_ir = rebel.torchscript_to_ir(
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
-
-        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
-
-        compiled_model = rebel.compile(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-            use_weight_sharing=True,
-        )
-        return compiled_model
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        """
-        Update user-given kwargs to get proper pytorch model.
-
-        For example, `torchscript`=True should be set because torch.jit
-        does not support `transformers` output instances as module output;
-        """
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-                "torch_dtype": torch.float32,
-                "_attn_implementation": "eager",
-            }
-        )
-        return kwargs
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> PreTrainedModel:
-        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
-            config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
-            if hf_position_embedding := getattr(config, "max_position_embeddings", None):
-                if hf_position_embedding < rbln_max_seq_len:
-                    logger.warning(
-                        f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
-                        "This may lead to incorrect inferences of the model."
-                    )
-            kwargs.update({"max_position_embeddings": rbln_max_seq_len})
-
-        return super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_config_kwargs=rbln_config_kwargs,
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
-            ignore_mismatched_sizes=True,
-            **kwargs,
-        )
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_prefill_chunk_size: Optional[int] = 128,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-    ) -> RBLNConfig:
-        meta = {}
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
-
-        if rbln_max_seq_len is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_max_length"):
-                    rbln_max_seq_len = tokenizer.model_max_length
-                    break
-            if rbln_max_seq_len is None:
-                raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        meta["rbln_prefill_chunk_size"] = rbln_prefill_chunk_size
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size if rbln_batch_size is not None else 1
-
-        def get_input_info(query_length):
-            input_info = [
-                ("input_ids", [rbln_batch_size, query_length], "int64"),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                (
-                    "cache_position",
-                    [],
-                    "int32",
-                ),
-            ]
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_batch_size,
-                            model_config.n_head,
-                            rbln_max_seq_len,
-                            model_config.hidden_size // model_config.n_head,
-                        ],
-                        "float32",
-                    )
-                    for i in range(model_config.n_layer * 2)
-                ]
-            )
-            return input_info
-
-        # model input info
-        prefill_input_info = get_input_info(query_length=rbln_prefill_chunk_size)
-        dec_input_info = get_input_info(query_length=1)
-
-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-        dec_rbln_runtime_config.batch_size = rbln_batch_size
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
-        )
-
-        return rbln_config
-
-    @classmethod
-    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
-    ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        if past_cached_length == 0:
-            mod_len = cur_len % self.prefill_chunk_size
-            self.pad_len = self.prefill_chunk_size - mod_len if mod_len > 0 else 0
-
-            prompt_attn_mask = torch.nn.functional.pad(attention_mask, (self.pad_len, 0), value=0)
-            self.prompt_attn_mask = prompt_attn_mask.reshape(batch_size, 1, 1, -1).contiguous()
-
-            input_ids = torch.nn.functional.pad(input_ids, (self.pad_len, 0), value=0)
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-            query_length = cur_len + self.pad_len
-        else:
-            attention_mask = torch.nn.functional.pad(
-                attention_mask, (self.pad_len, self.max_seq_len - cur_len - self.pad_len)
-            )
-            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            input_ids = input_ids[:, -1:].contiguous()
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_cached_length,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: int = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-        past_cached_length = past_key_values
-
-        if past_cached_length is not None:
-            past_cached_length += query_length
-
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                attention_mask[:, :, :, :query_length] *= self.prompt_attn_mask
-
-                output = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask,
-                    cache_position=cache_position + step,
-                )
-                cache_position += self.prefill_chunk_size
-        else:
-            output = self.decoder(
-                input_ids=input_ids.contiguous(),
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-            )
-        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
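The roughly 250 lines removed here implemented midm's two-phase execution: a fixed-shape prefill runtime consumes the (padded) prompt in `rbln_prefill_chunk_size` chunks, then a query-length-1 decode runtime emits one token per step, with the KV cache shared between the two via the compile-time `connections`. In 0.1.9 that scheduling moves into the shared `RBLNDecoderOnlyModelForCausalLM`. A standalone sketch of the idea, with hypothetical `prefill_fn`/`decode_fn` standing in for the two compiled runtimes:

```python
# Sketch of the chunked-prefill scheduling the deleted forward() implemented,
# now centralized in the shared base class. prefill_fn/decode_fn are
# hypothetical stand-ins for the two fixed-shape RBLN runtimes.
import torch

def generate_step(input_ids: torch.Tensor, cache_position: int,
                  prefill_fn, decode_fn, chunk_size: int = 128) -> torch.Tensor:
    if cache_position == 0:
        # Prefill: feed the prompt through the static-shape runtime one
        # chunk at a time; each call fills its slice of the KV cache.
        for step in range(0, input_ids.shape[1], chunk_size):
            logits = prefill_fn(input_ids[:, step : step + chunk_size],
                                cache_position + step)
    else:
        # Decode: one token per call through the query-length-1 runtime.
        logits = decode_fn(input_ids[:, -1:], cache_position)
    return logits
```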
optimum/rbln/transformers/models/mistral/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_mistral import RBLNMistralForCausalLM
optimum/rbln/transformers/models/mistral/mistral_architecture.py
@@ -0,0 +1,29 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class MistralForCausalLMWrapper(DecoderOnlyWrapper):
+    pass
optimum/rbln/transformers/models/mistral/modeling_mistral.py
@@ -0,0 +1,68 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import MistralForCausalLM
+
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .mistral_architecture import MistralForCausalLMWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from ....modeling_config import RBLNConfig
+
+
+logger = logging.getLogger(__name__)
+
+
+class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Llama Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    """
+
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
+        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(MistralForCausalLM, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
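With the shared base class in place, Mistral support reduces to the wrapper and subclass above (note the docstring still says "Llama", an upstream copy-paste artifact). End-to-end usage would look roughly like the following; the `export` flow and `rbln_*` kwargs follow the conventions visible elsewhere in this diff, and the exact argument set (and the top-level re-export from `optimum.rbln`) is an assumption, not a guarantee:

```python
# Illustrative usage, assuming an RBLN SDK environment with an NPU available;
# the model id and kwarg values are examples only.
from transformers import AutoTokenizer
from optimum.rbln import RBLNMistralForCausalLM

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile the HF checkpoint into an RBLN model (one-time, NPU-targeted).
model = RBLNMistralForCausalLM.from_pretrained(
    model_id,
    export=True,            # trace + compile instead of loading a prebuilt artifact
    rbln_max_seq_len=4096,  # static sequence length baked into the compiled graph
    rbln_batch_size=1,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```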
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -70,7 +70,7 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
     auto_model_class = AutoModelForMaskedLM
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _Wav2Vec2(model).eval()
 
     @classmethod
optimum/rbln/transformers/models/whisper/whisper_architecture.py
@@ -57,7 +57,6 @@ class _WhisperAttention(WhisperAttention):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
 
@@ -123,7 +122,6 @@ class _WhisperSdpaAttention(WhisperSdpaAttention):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, tgt_len, _ = hidden_states.size()
 
         is_cross_attention = key_value_states is not None
@@ -189,7 +187,6 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
     ) -> torch.Tensor:
-
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -248,7 +245,6 @@ class _WhisperDecoder(WhisperDecoder):
         attn_impl: str = "eager",
         **kwargs,
     ):
-
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
 
@@ -312,7 +308,6 @@ class _WhisperDecoderWrapper(torch.nn.Module):
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
-
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):
@@ -367,7 +362,6 @@ class _WhisperEncoderWrapper(torch.nn.Module):
         self,
         input_features: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
-
        encoder_outputs = self.encoder(input_features=input_features)
        last_hidden_states = encoder_outputs[0]
 
optimum/rbln/transformers/models/xlm_roberta/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_xlm_roberta import RBLNXLMRobertaModel
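Usage of the new XLM-RoBERTa encoder would follow the same shape, assuming `RBLNXLMRobertaModel` is re-exported from `optimum.rbln` like the other classes added in this release; treat the call below as a sketch of the intended flow rather than a verified API:

```python
# Illustrative sketch only: assumes the top-level re-export and the same
# export=True compile flow used by the other RBLN model classes.
from transformers import AutoTokenizer
from optimum.rbln import RBLNXLMRobertaModel

model_id = "FacebookAI/xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile the encoder for the NPU (one-time).
model = RBLNXLMRobertaModel.from_pretrained(model_id, export=True)

inputs = tokenizer("RBLN-compiled encoder inference", return_tensors="pt")
outputs = model(**inputs)  # hidden states from the compiled encoder
```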