optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
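
At a glance, 0.1.8 adds DPT depth estimation, Gemma, and XLM-RoBERTa support, factors Llama/Gemma/Midm into a shared decoderonly package, introduces RebelDynamicCache, and moves export orchestration into a reworked RBLNModel base class. A minimal sketch of loading one of the new classes, assuming the optimum-style export=True flag and a root-level re-export (conventions suggested by, but not shown in, this diff):

from optimum.rbln import RBLNDPTForDepthEstimation  # new in 0.1.8

# "Intel/dpt-large" is an illustrative checkpoint, not one named in this diff;
# export=True compiles the original Transformers weights for the RBLN NPU.
model = RBLNDPTForDepthEstimation.from_pretrained("Intel/dpt-large", export=True)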
--- a/optimum/rbln/modeling_seq2seq.py
+++ b/optimum/rbln/modeling_seq2seq.py
@@ -23,13 +23,10 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
-from optimum.exporters import TasksManager
 from transformers import (
     AutoModelForSeq2SeqLM,
     BartConfig,
@@ -39,12 +36,11 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from .modeling_base import RBLNBaseModel
+from .modeling_base import RBLNModel
 from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
 from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
 from .utils.runtime_utils import RBLNPytorchRuntime
-from .utils.save_utils import maybe_save_preprocessors
 
 
 logger = logging.getLogger(__name__)
@@ -75,7 +71,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNModelForSeq2SeqLM(RBLNBaseModel):
+class RBLNModelForSeq2SeqLM(RBLNModel):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -88,7 +84,6 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
     Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
     """
 
-    model_type = "rbln_model"
     auto_model_class = AutoModelForSeq2SeqLM
 
     def __post_init__(self, **kwargs):
@@ -97,8 +92,8 @@
         self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
         self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
         self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
-        self.encoder = RBLNRuntimeEncoder(runtime=self.runtimes[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
 
     def can_generate(self):
         return True
@@ -149,74 +144,18 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
         }
 
     @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "AutoModelForSeq2SeqLM":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
+    def update_kwargs(cls, kwargs):
         kwargs.update(
             {
                 "torchscript": True,
                 "return_dict": False,
-                "use_cache": False,
+                "use_cache": True,
             }
         )
+        return kwargs
 
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: AutoModelForSeq2SeqLM = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
+    @classmethod
+    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
         def optimized_models(model):
             if isinstance(model, T5ForConditionalGeneration):
                 encoder_model = T5EncoderWrapper(model).eval()
@@ -229,67 +168,54 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
 
             return encoder_model, decoder_model
 
-        def compile():
-            wrapped_encoder, wrapped_decoder = optimized_models(model)
+        wrapped_encoder, wrapped_decoder = optimized_models(model)
 
-            wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-            wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-            wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+        wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
+        wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
+        wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
 
-            wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-            wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-            wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+        wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
+        wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
+        wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
 
-            enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
-            if isinstance(model, T5ForConditionalGeneration):
-                enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-                dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
-            else:
-                enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
-                dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-            enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs)
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-            enc_ir = rebel.torchscript_to_ir(
-                enc_scripted_model,
-                input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-                name=enc_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-                name=dec_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
-
-            connections = [
-                (enc_ir.outputs[0], dec_ir.inputs[5]),
-                (dec_ir.outputs[1], dec_ir.inputs[4]),
-            ]
-            compiled_model = rebel.compile(
-                enc_ir,
-                dec_ir,
-                connections=connections,
-                fusion=enc_rbln_runtime_config.fusion,
-                npu=enc_rbln_runtime_config.npu,
-                tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile()
-
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
+        if isinstance(model, T5ForConditionalGeneration):
+            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+        else:
+            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+
+        enc_ir = rebel.torchscript_to_ir(
+            enc_scripted_model,
+            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
+            name=enc_rbln_runtime_config.rbln_mod_name,
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+            name=dec_rbln_runtime_config.rbln_mod_name,
         )
+        dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+
+        connections = [
+            (enc_ir.outputs[0], dec_ir.inputs[5]),
+            (dec_ir.outputs[1], dec_ir.inputs[4]),
+        ]
+        compiled_model = rebel.compile(
+            enc_ir,
+            dec_ir,
+            connections=connections,
+            fusion=enc_rbln_runtime_config.fusion,
+            npu=enc_rbln_runtime_config.npu,
+            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
+        )
+        return compiled_model
 
     @classmethod
     def _get_rbln_config(
@@ -411,11 +337,14 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
 
         return rbln_config
 
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            self.compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
     def forward(
@@ -436,9 +365,6 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
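
Net effect of the modeling_seq2seq.py hunks above: the monolithic _export (checkpoint loading, preprocessor saving, compilation, serialization) is deleted, and the class now only contributes update_kwargs, get_compiled_model, and the classmethod _create_runtimes; the shared orchestration moved into the reworked RBLNModel base (modeling_base.py, +282/-100, not shown here). End-user code should be unchanged; a hedged sketch, assuming RBLNModelForSeq2SeqLM is re-exported at the package root and accepts the optimum-style export=True flag:

from transformers import AutoTokenizer
from optimum.rbln import RBLNModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
# get_compiled_model() traces both wrappers and fuses them into one .rbln artifact.
model = RBLNModelForSeq2SeqLM.from_pretrained("facebook/bart-base", export=True)
inputs = tokenizer("Hello, world!", return_tensors="pt")
print(tokenizer.batch_decode(model.generate(**inputs), skip_special_tokens=True))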
--- a/optimum/rbln/transformers/__init__.py
+++ b/optimum/rbln/transformers/__init__.py
@@ -27,28 +27,36 @@ from transformers.utils import _LazyModule
 
 
 _import_structure = {
+    "cache_utils": ["RebelDynamicCache"],
     "generation": ["BatchTextIteratorStreamer"],
     "models": [
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNDPTForDepthEstimation",
+        "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
+        "RBLNXLMRobertaModel"
     ],
 }
 
 if TYPE_CHECKING:
+    from .cache_utils import RebelDynamicCache
     from .generation import BatchTextIteratorStreamer
    from .models import (
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNDPTForDepthEstimation,
+        RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
+        RBLNXLMRobertaModel,
     )
 else:
     import sys
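
Because these names are registered with _LazyModule, they resolve on first attribute access. A quick smoke test of the new registrations (names taken verbatim from the hunk above):

# Both names were added to _import_structure in this release.
from optimum.rbln.transformers import RebelDynamicCache, RBLNXLMRobertaModel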
--- /dev/null
+++ b/optimum/rbln/transformers/cache_utils.py
@@ -0,0 +1,111 @@
+from typing import Optional, Tuple
+
+import torch
+from transformers.cache_utils import DynamicCache
+
+
+class RebelDynamicCache(DynamicCache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+    """
+
+    def __init__(self, current_steps) -> None:
+        super().__init__()
+        self.current_steps = current_steps
+
+    def assign(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+    ) -> None:
+        self.key_cache[layer_idx] = key_states.squeeze(2)
+        self.value_cache[layer_idx] = value_states.squeeze(2)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_idx: int,
+        read_first_step: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` and the batch 'batch_inx'
+        based on self.current_step,
+        """
+        current_step = self.current_steps[0 if read_first_step else batch_idx]
+        kend = current_step + key_states.shape[-2]
+        vend = current_step + value_states.shape[-2]
+        update_key_states = (
+            self.key_cache[layer_idx][batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(key_states, dim=-2, start=current_step, end=kend)
+        )
+        update_value_states = (
+            self.value_cache[layer_idx][batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(value_states, dim=-2, start=current_step, end=vend)
+        )
+
+        return update_key_states, update_value_states
+
+    @classmethod
+    def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "DynamicCache":
+        """Converts a cache in the rbln cache format (list of past_kv) into an equivalent `DynamicCache`."""
+
+        batch, _ = position_ids.shape
+        current_steps = [position_ids[b][0] for b in range(batch)]
+
+        assert len(current_steps) == batch
+        cache = cls(current_steps)
+
+        for layer_idx in range(num_hidden_layer):
+            key_states = past_key_values[layer_idx * 2]
+            value_states = past_key_values[layer_idx * 2 + 1]
+            cache.key_cache.append(key_states)
+            cache.value_cache.append(value_states)
+
+        return cache
+
+
+class RebelDynamicCache_4D(RebelDynamicCache):
+    def assign(
+        self,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        layer_idx: int,
+    ) -> None:
+        self.key_cache[layer_idx] = keys
+        self.value_cache[layer_idx] = values
+
+    def update(
+        self,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        layer_idx: int,
+        batch_idx: int,
+        read_first_step: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `keys` and `values` for the layer `layer_idx` and the batch 'batch_inx'
+        based on self.current_step,
+        """
+        current_step = self.current_steps[0 if read_first_step else batch_idx]
+        kend = current_step + keys.shape[-2]
+        vend = current_step + values.shape[-2]
+        update_keys = (
+            self.key_cache[layer_idx][batch_idx].unsqueeze(0).slice_scatter(keys, dim=-2, start=current_step, end=kend)
+        )
+        update_values = (
+            self.value_cache[layer_idx][batch_idx]
+            .unsqueeze(0)
+            .slice_scatter(values, dim=-2, start=current_step, end=vend)
+        )
+
+        return update_keys, update_values
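
RebelDynamicCache keeps one cursor per batch entry (current_steps), so update() can slice_scatter each sequence's new key/value states at that sequence's own decoding offset. A rough usage sketch with made-up shapes (2 sequences resuming at steps 5 and 9, 2 layers, 4 heads, max length 64, head dim 8); all values here are illustrative:

import torch
from optimum.rbln.transformers import RebelDynamicCache

batch, layers, heads, max_seq, head_dim = 2, 2, 4, 64, 8
position_ids = torch.tensor([[5], [9]])  # first position marks each sequence's resume step
# Flat (key0, value0, key1, value1, ...) layout, as from_input_format expects.
past_key_values = [torch.zeros(batch, heads, max_seq, head_dim) for _ in range(layers * 2)]
cache = RebelDynamicCache.from_input_format(position_ids, layers, *past_key_values)

# Scatter one decoded token's key/value for layer 0 of batch entry 1 at step 9.
new_key = torch.randn(1, heads, 1, 1, head_dim)    # extra dim matches the unsqueeze(2) convention
new_value = torch.randn(1, heads, 1, 1, head_dim)
key_states, value_states = cache.update(new_key, new_value, layer_idx=0, batch_idx=1)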
--- a/optimum/rbln/transformers/generation/utils.py
+++ b/optimum/rbln/transformers/generation/utils.py
@@ -32,7 +32,6 @@ class RBLNGenerationMixin:
         generation_config: Optional[GenerationConfig] = None,  # thkim change for 4.41.0
         **model_kwargs,
     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
         ###################### thkim change for 4.41.0 ############################
         if generation_config is not None:
             pad_token_id = generation_config.pad_token_id
@@ -216,7 +215,6 @@
         do_sample: Optional[bool] = True,
         **model_kwargs,
     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
         ###################### thkim change for 4.41.0 ############################
         if generation_config is not None:
             pad_token_id = generation_config.pad_token_id
--- a/optimum/rbln/transformers/models/__init__.py
+++ b/optimum/rbln/transformers/models/__init__.py
@@ -22,8 +22,11 @@
 # from Rebellions Inc.
 
 from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
+from .dpt import RBLNDPTForDepthEstimation
+from .gemma import RBLNGemmaForCausalLM
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
 from .midm import RBLNMidmLMHeadModel
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
+from .xlm_roberta import RBLNXLMRobertaModel
--- a/optimum/rbln/transformers/models/bart/bart_architecture.py
+++ b/optimum/rbln/transformers/models/bart/bart_architecture.py
@@ -56,7 +56,6 @@ class _BartAttention(BartAttention):
         cache_position: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
 
@@ -111,7 +110,6 @@ class _BartSdpaAttention(BartSdpaAttention):
         cache_position: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
 
@@ -166,7 +164,6 @@ class _BartDecoderLayer(BartDecoderLayer):
         cache_position: torch.Tensor,
         attn_impl: str = "eager",
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-
         # Self Attention Block
         residual = hidden_states
         self_attn_past_key_value = past_key_value[:2]
@@ -218,7 +215,6 @@ class _BartDecoder(BartDecoder):
         cache_position: torch.Tensor,
         attn_impl: str = "eager",
     ):
-
         # embedding
         positions_idx = cache_position + self.embed_positions.offset
         positions = self.embed_positions.weight[positions_idx]
@@ -284,7 +280,6 @@ class BartDecoderWrapper(torch.nn.Module):
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
-
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):
--- a/optimum/rbln/transformers/models/clip/modeling_clip.py
+++ b/optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -51,7 +51,6 @@ class _TextEncoder(torch.nn.Module):
 
 
 class RBLNCLIPTextModel(RBLNModel):
-    model_type = "rbln_clip"
     auto_model_class = AutoModel  # feature extraction
     original_model_class = CLIPTextModel
     original_config_class = CLIPTextConfig
--- /dev/null
+++ b/optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyDecoderLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    DynamicNTKScalingRotaryEmbedding,
+    LinearScalingRotaryEmbedding,
+    RotaryEmbedding,
+    apply_rotary_pos_emb,
+    rotate_half,
+    slice_and_unsqueeze_cos_sin,
+)
+from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
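
This decoderonly package is what lets the Llama files shrink (llama_architecture.py -613, modeling_llama.py -469): Llama, Gemma, and Midm now share one attention/wrapper stack. Judging by how thin the per-model files are (modeling_gemma.py is 61 lines), a concrete model mostly just subclasses the shared base; a speculative outline, not the actual 0.1.8 source:

from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM

class RBLNMyLMForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """Hypothetical subclass: the shared base is assumed to supply compilation,
    runtime creation, and RebelDynamicCache handling for decoder-only LMs, so
    model-specific classes only override what differs (e.g., the torch wrapper)."""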