optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/midm/modeling_midm.py
@@ -23,17 +23,12 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-import rebel
-import torch
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers import PretrainedConfig, PreTrainedModel
 
-from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...generation.utils import RBLNGenerationMixin
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (
     MidmLMHeadModelWrapper,
@@ -41,7 +36,6 @@ from .midm_architecture import (
 
 
 logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import (
         AutoFeatureExtractor,
@@ -51,31 +45,12 @@ if TYPE_CHECKING:
     )
 
 
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    # RBLN_Runtimemodule
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.LongTensor = None,
-        cache_position: torch.Tensor = None,
-        **kwargs: Dict[str, Any],
-    ):
-        logits = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-        return logits
-
-
-class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
+class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
     The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers Midm model into a RBLN transformer model by:
@@ -84,46 +59,9 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
 
     """
 
-    model_type = "rbln_model"
-    auto_model_class = AutoModelForCausalLM
-    main_input_name = "input_ids"
-
-    def __init__(
-        self,
-        models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
-        config: PretrainedConfig = None,
-        preprocessors: Optional[List] = None,
-        rbln_config: Optional[RBLNConfig] = None,
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            models,
-            config,
-            preprocessors,
-            rbln_config,
-            rbln_device=rbln_device,
-            rbln_device_map=rbln_device_map,
-            **kwargs,
-        )
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-        self.prefill_attention_mask = torch.zeros(
-            self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
-        self.past_cached_length = 0
-
-    def can_generate(self):
-        return True
+    @classmethod
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         """This is the key method to implement RBLN-Midm.
@@ -140,142 +78,46 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
             return redirect(val)
         return val
 
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # TODO(jongho): implement
-        raise NotImplementedError
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
-        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
-
-        prefill_ir = rebel.torchscript_to_ir(
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
-
-        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
-
-        compiled_model = rebel.compile(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-            use_weight_sharing=True,
-        )
-        return compiled_model
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        """
-        Update user-given kwargs to get proper pytorch model.
-
-        For example, `torchscript`=True should be set because torch.jit
-        does not support `transformers` output instances as module output;
-        """
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-                "torch_dtype": torch.float32,
-                "_attn_implementation": "eager",
-            }
-        )
-        return kwargs
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> PreTrainedModel:
-        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
-            config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
-            if hf_position_embedding := getattr(config, "max_position_embeddings", None):
-                if hf_position_embedding < rbln_max_seq_len:
-                    logger.warning(
-                        f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
-                        "This may lead to incorrect inferences of the model."
-                    )
-            kwargs.update({"max_position_embeddings": rbln_max_seq_len})
-
-        return super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_config_kwargs=rbln_config_kwargs,
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
-            ignore_mismatched_sizes=True,
-            **kwargs,
-        )
-
     @classmethod
     def _get_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_prefill_chunk_size: Optional[int] = 128,
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
+        **kwargs,
     ) -> RBLNConfig:
         meta = {}
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
 
+        prefill_chunk_size = 128
         if rbln_max_seq_len is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_max_length"):
-                    rbln_max_seq_len = tokenizer.model_max_length
-                    break
-            if rbln_max_seq_len is None:
-                raise ValueError("`rbln_max_seq_len` should be specified!")
+            rbln_max_seq_len = getattr(model_config, "n_positions", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        meta["rbln_prefill_chunk_size"] = rbln_prefill_chunk_size
         meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size if rbln_batch_size is not None else 1
-
-        def get_input_info(query_length):
+        meta["rbln_batch_size"] = rbln_batch_size
+        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+
+        def get_input_info(
+            batch_size,
+            query_length,
+        ):
+            head_dim = (
+                model_config.head_dim
+                if hasattr(model_config, "head_dim")
+                else model_config.hidden_size // model_config.n_head
+            )
             input_info = [
-                ("input_ids", [rbln_batch_size, query_length], "int64"),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                ("input_ids", [batch_size, query_length], "int64"),
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                 (
                     "cache_position",
-                    [],
+                    [batch_size, query_length],
                     "int32",
                 ),
+                ("batch_position", [], "int16"),
             ]
+
            input_info.extend(
                 [
                     (
@@ -284,18 +126,24 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
                             rbln_batch_size,
                             model_config.n_head,
                             rbln_max_seq_len,
-                            model_config.hidden_size // model_config.n_head,
+                            head_dim,
                         ],
                         "float32",
                     )
                     for i in range(model_config.n_layer * 2)
                 ]
             )
+
             return input_info
 
-        # model input info
-        prefill_input_info = get_input_info(query_length=rbln_prefill_chunk_size)
-        dec_input_info = get_input_info(query_length=1)
+        prefill_input_info = get_input_info(
+            batch_size=1,
+            query_length=prefill_chunk_size,
+        )
+        dec_input_info = get_input_info(
+            batch_size=rbln_batch_size,
+            query_length=1,
+        )
 
         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
@@ -308,83 +156,3 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
         )
 
         return rbln_config
-
-    @classmethod
-    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
-    ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        if past_cached_length == 0:
-            mod_len = cur_len % self.prefill_chunk_size
-            self.pad_len = self.prefill_chunk_size - mod_len if mod_len > 0 else 0
-
-            prompt_attn_mask = torch.nn.functional.pad(attention_mask, (self.pad_len, 0), value=0)
-            self.prompt_attn_mask = prompt_attn_mask.reshape(batch_size, 1, 1, -1).contiguous()
-
-            input_ids = torch.nn.functional.pad(input_ids, (self.pad_len, 0), value=0)
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-            query_length = cur_len + self.pad_len
-        else:
-            attention_mask = torch.nn.functional.pad(
-                attention_mask, (self.pad_len, self.max_seq_len - cur_len - self.pad_len)
-            )
-            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            input_ids = input_ids[:, -1:].contiguous()
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_cached_length,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: int = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-        past_cached_length = past_key_values
-
-        if past_cached_length is not None:
-            past_cached_length += query_length
-
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                attention_mask[:, :, :, :query_length] *= self.prompt_attn_mask
-
-                output = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask,
-                    cache_position=cache_position + step,
-                )
-                cache_position += self.prefill_chunk_size
-        else:
-            output = self.decoder(
-                input_ids=input_ids.contiguous(),
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-            )
-        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
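
Taken together, the changes to modeling_midm.py remove the class's own compile, runtime-creation, prefill-scheduling, and generation code; RBLNMidmLMHeadModel now only supplies the traced-model wrapper (wrapping_torch_model) and its compile-time configuration (_get_rbln_config), inheriting the rest from the new RBLNDecoderOnlyModelForCausalLM base class. A minimal usage sketch follows; the checkpoint id, the optimum-style export=True flag, and the top-level import are assumptions for illustration, while the rbln_* keyword arguments mirror those accepted by _get_rbln_config above.

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNMidmLMHeadModel

    model_id = "KT-AI/midm-bitext-S-7B-inst-v1"  # illustrative checkpoint id, not taken from this diff

    # Compile the Hugging Face checkpoint for the RBLN NPU (export flag assumed from the optimum convention).
    model = RBLNMidmLMHeadModel.from_pretrained(
        model_id,
        export=True,
        trust_remote_code=True,  # Midm ships custom modeling code on the Hub
        rbln_max_seq_len=8192,   # falls back to config.n_positions when omitted (see _get_rbln_config)
        rbln_batch_size=1,       # decode-phase batch size; the prefill graph is compiled with batch size 1
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))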
optimum/rbln/transformers/models/whisper/whisper_architecture.py
@@ -57,7 +57,6 @@ class _WhisperAttention(WhisperAttention):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
 
@@ -123,7 +122,6 @@ class _WhisperSdpaAttention(WhisperSdpaAttention):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, tgt_len, _ = hidden_states.size()
 
         is_cross_attention = key_value_states is not None
@@ -189,7 +187,6 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
     ) -> torch.Tensor:
-
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -248,7 +245,6 @@ class _WhisperDecoder(WhisperDecoder):
         attn_impl: str = "eager",
         **kwargs,
     ):
-
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
 
@@ -312,7 +308,6 @@ class _WhisperDecoderWrapper(torch.nn.Module):
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
-
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):
@@ -367,7 +362,6 @@ class _WhisperEncoderWrapper(torch.nn.Module):
         self,
         input_features: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
-
         encoder_outputs = self.encoder(input_features=input_features)
         last_hidden_states = encoder_outputs[0]
 
optimum/rbln/transformers/models/xlm_roberta/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_xlm_roberta import RBLNXLMRobertaModel
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -0,0 +1,125 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import torch
+from transformers import AutoModel, PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+class RBLNXLMRobertaModel(RBLNModel):
+    auto_model_class = AutoModel  # feature extraction
+    original_model_class = XLMRobertaModel
+    original_config_class = XLMRobertaConfig
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        model_id: str,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> "PreTrainedModel":
+        model: "PreTrainedModel" = super().get_pytorch_model(
+            model_id=model_id,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            subfolder=subfolder,
+            local_files_only=local_files_only,
+            trust_remote_code=trust_remote_code,
+            rbln_config_kwargs=rbln_config_kwargs,
+            rbln_constructor_kwargs=rbln_constructor_kwargs,
+            library_name="transformers",
+        )
+
+        return model
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+    def forward(self, input_ids: "torch.Tensor", attention_mask: "torch.Tensor", token_type_ids: "torch.Tensor" = None, **kwargs):
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input=input_ids, dtype=torch.int64)
+        output = super().forward(input_ids, attention_mask, token_type_ids)
+        return output
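
The new RBLNXLMRobertaModel compiles the encoder with a static input shape of [rbln_batch_size, rbln_max_seq_len] and substitutes zero-filled token_type_ids when the caller omits them. A minimal inference sketch under those assumptions follows; the checkpoint id, the optimum-style export=True flag, and the padding settings are illustrative rather than taken from this diff.

    from transformers import AutoTokenizer
    from optimum.rbln.transformers.models.xlm_roberta import RBLNXLMRobertaModel

    model_id = "FacebookAI/xlm-roberta-base"  # illustrative checkpoint id

    model = RBLNXLMRobertaModel.from_pretrained(
        model_id,
        export=True,           # assumed optimum-style flag that triggers compilation
        rbln_max_seq_len=512,  # must cover the padded input length
        rbln_batch_size=1,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Inputs must match the compiled static shape: [rbln_batch_size, rbln_max_seq_len].
    encoded = tokenizer(
        "Rebellions builds AI accelerators.",
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )
    # token_type_ids can be omitted; forward() fills in torch.zeros_like(input_ids).
    outputs = model(encoded["input_ids"], encoded["attention_mask"])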
{optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: optimum-rbln
-Version: 0.1.7
+Version: 0.1.8
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators.
 It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Keywords: transformers,diffusers,inference,rbln,atom,rebel
@@ -21,7 +21,7 @@ Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
 Requires-Python: <3.11,>=3.8
 Requires-Dist: torch<=2.2.1
-Requires-Dist: optimum>=1.17.1
+Requires-Dist: optimum<=1.20.0
 Requires-Dist: accelerate>=0.28.0
 Requires-Dist: transformers<=4.40.2
 Requires-Dist: diffusers<=0.29.2
@@ -35,7 +35,6 @@ Requires-Dist: sentencepiece>=0.2.0; extra == "tests"
 Requires-Dist: datasets>=2.18.0; extra == "tests"
 Requires-Dist: sacremoses>=0.1.1; extra == "tests"
 Requires-Dist: safetensors>=0.4.2; extra == "tests"
-Requires-Dist: black>=24.3.0; extra == "quality"
 Requires-Dist: ruff>=0.3.3; extra == "quality"
 Requires-Dist: isort>=5.13.2; extra == "quality"
 Requires-Dist: hf-doc-builder>=0.5.0; extra == "quality"