optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -22,17 +22,17 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
@@ -63,103 +63,152 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
                - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnets = kwargs.pop("controlnet", None)
         vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        text_encoder_2 = kwargs.pop("text_encoder_2", None)
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
+
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "vae": vae,
-            "controlnet": controlnets,
-            "text_encoder": text_encoder,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         do_classifier_free_guidance = (
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
-        vae = RBLNAutoencoderKL.from_pretrained(
-            model_id=model_id,
-            subfolder="vae",
-            export=True,
-            rbln_unet_sample_size=model.unet.config.sample_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder = RBLNCLIPTextModel.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder_2",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
+                export=True,
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=True,
+                rbln_vae_scale_factor=model.vae_scale_factor,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        unet = RBLNUNet2DConditionModel.from_pretrained(
-            model_id=model_id,
-            subfolder="unet",
-            export=True,
-            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
-            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-            rbln_batch_size=unet_batch_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
+                export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        if isinstance(controlnets, (list, tuple)):
-            controlnet = RBLNMultiControlNetModel.from_pretrained(
-                model_id=str(save_dir_path / "controlnet"),
+        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
+            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder_2",
                 export=True,
-                rbln_batch_size=unet_batch_size,
-                rbln_vae_scale_factor=model.vae_scale_factor,
+                model_save_dir=model_save_dir,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-            controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-        else:
-            controlnet = RBLNControlNetModel.from_pretrained(
-                model_id=save_dir_path / "controlnet",
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
                 export=True,
-                rbln_batch_size=unet_batch_size,
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
                 rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                rbln_batch_size=unet_batch_size,
+                rbln_use_encode=True,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-            controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+
+        # replace modules
         model.vae = vae
         model.text_encoder = text_encoder
         model.unet = unet
         model.text_encoder_2 = text_encoder_2
         model.controlnet = controlnet
 
+        # update config to be able to load from file
         update_dict = {
             "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
             "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
@@ -169,14 +218,24 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
         }
         model.register_to_config(**update_dict)
 
-        model.models = [
-            vae.model[0],
-            vae.model[1],
-            unet.model[0],
-            text_encoder.model[0],
-            text_encoder_2.model[0],
-            controlnet.model[0],
-        ]
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        # use for CI to access each compiled model
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                vae.compiled_models[1],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+                unet.compiled_models[0],
+            ]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
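Taken together, these hunks rework `from_pretrained` so that already-built submodules can be passed in and reused, and so compiled artifacts are written to a persistent `model_save_dir` instead of a temporary directory. A minimal usage sketch of the reworked flow, assuming the pipeline class is exported from the package root as in earlier releases; the checkpoint ids and `rbln_*` values below are illustrative only, not taken from the diff:

    # Hedged sketch; model ids and rbln_* values are illustrative.
    from diffusers import ControlNetModel
    from optimum.rbln import RBLNStableDiffusionXLControlNetImg2ImgPipeline

    controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
    pipe = RBLNStableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnet,          # passed through because it is a ControlNetModel
        export=True,                    # compile vae / text encoders / unet / controlnet
        model_save_dir="sdxl_cn_rbln",  # compiled modules persist here, not in a temp dir
        rbln_batch_size=1,
        rbln_guidance_scale=5.0,
    )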

optimum/rbln/modeling_alias.py
@@ -24,7 +24,9 @@
 from .modeling_base import (
     RBLNModelForAudioClassification,
     RBLNModelForImageClassification,
+    RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
+    RBLNModelForSequenceClassification,
 )
 from .modeling_seq2seq import RBLNModelForSeq2SeqLM
 
@@ -47,3 +49,15 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     pass
+
+
+class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
+    pass
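The three new aliases are thin subclasses of the generic heads added to modeling_base.py in the next hunks. A minimal compile sketch for one of them, assuming the alias is re-exported from the package root like the existing aliases; the checkpoint id and `rbln_*` values are illustrative:

    # Hedged sketch; checkpoint id and rbln_* values are illustrative.
    from optimum.rbln import RBLNXLMRobertaForSequenceClassification

    model = RBLNXLMRobertaForSequenceClassification.from_pretrained(
        "xlm-roberta-base",   # hypothetical checkpoint id
        export=True,          # convert and compile for RBLN devices
        rbln_max_seq_len=512,
        rbln_batch_size=1,
    )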

optimum/rbln/modeling_base.py
@@ -39,7 +39,9 @@ from transformers import (
     AutoModel,
     AutoModelForAudioClassification,
     AutoModelForImageClassification,
+    AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
+    AutoModelForSequenceClassification,
     GenerationConfig,
     PretrainedConfig,
 )
@@ -748,3 +750,111 @@ class RBLNModelForAudioClassification(RBLNModel):
         )
 
         return rbln_config
+
+
+class RBLNModelForSequenceClassification(RBLNModel):
+    """
+    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based SequenceClassification models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers SequenceClassification model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    Currently, this model class supports the 'XLMRoberta' and 'Roberta' model from the transformers library. Future updates may include support for additional model types.
+    """
+
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForSequenceClassification
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+class RBLNModelForMaskedLM(RBLNModel):
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForMaskedLM
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
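Both new heads resolve the sequence length from the model config (falling back to the tokenizer's `model_max_length`) and then declare one fixed-shape int64 input per model input name. A small, self-contained illustration of the `input_info` both `_get_rbln_config` implementations build, with illustrative sizes:

    # Illustration of the static input specs produced above (sizes are illustrative).
    rbln_batch_size, rbln_max_seq_len = 1, 512
    rbln_model_input_names = ["input_ids", "attention_mask"]
    input_info = [
        (name, [rbln_batch_size, rbln_max_seq_len], "int64")
        for name in rbln_model_input_names
    ]
    print(input_info)
    # [('input_ids', [1, 512], 'int64'), ('attention_mask', [1, 512], 'int64')]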

optimum/rbln/transformers/__init__.py
@@ -27,30 +27,36 @@ from transformers.utils import _LazyModule
 
 
 _import_structure = {
+    "cache_utils": ["RebelDynamicCache"],
     "generation": ["BatchTextIteratorStreamer"],
     "models": [
        "RBLNCLIPTextModel",
        "RBLNCLIPTextModelWithProjection",
        "RBLNDPTForDepthEstimation",
+       "RBLNGemmaForCausalLM",
        "RBLNGPT2LMHeadModel",
        "RBLNWav2Vec2ForCTC",
        "RBLNWhisperForConditionalGeneration",
        "RBLNLlamaForCausalLM",
        "RBLNMidmLMHeadModel",
+       "RBLNXLMRobertaModel"
     ],
 }
 
 if TYPE_CHECKING:
+    from .cache_utils import RebelDynamicCache
     from .generation import BatchTextIteratorStreamer
     from .models import (
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNDPTForDepthEstimation,
+        RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
+        RBLNXLMRobertaModel,
     )
 else:
     import sys

optimum/rbln/transformers/cache_utils.py (new file)
@@ -0,0 +1,111 @@
+from typing import Optional, Tuple
+
+import torch
+from transformers.cache_utils import DynamicCache
+
+
+class RebelDynamicCache(DynamicCache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+    """
+
+    def __init__(self, current_steps) -> None:
+        super().__init__()
+        self.current_steps = current_steps
+
+    def assign(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+    ) -> None:
+        self.key_cache[layer_idx] = key_states.squeeze(2)
+        self.value_cache[layer_idx] = value_states.squeeze(2)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_idx: int,
+        read_first_step: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` and the batch 'batch_inx'
+        based on self.current_step,
+        """
+        current_step = self.current_steps[0 if read_first_step else batch_idx]
+        kend = current_step + key_states.shape[-2]
+        vend = current_step + value_states.shape[-2]
+        update_key_states = (
+            self.key_cache[layer_idx][batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(key_states, dim=-2, start=current_step, end=kend)
+        )
+        update_value_states = (
+            self.value_cache[layer_idx][batch_idx]
+            .unsqueeze(0)
+            .unsqueeze(2)
+            .slice_scatter(value_states, dim=-2, start=current_step, end=vend)
+        )
+
+        return update_key_states, update_value_states
+
+    @classmethod
+    def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "DynamicCache":
+        """Converts a cache in the rbln cache format (list of past_kv) into an equivalent `DynamicCache`."""
+
+        batch, _ = position_ids.shape
+        current_steps = [position_ids[b][0] for b in range(batch)]
+
+        assert len(current_steps) == batch
+        cache = cls(current_steps)
+
+        for layer_idx in range(num_hidden_layer):
+            key_states = past_key_values[layer_idx * 2]
+            value_states = past_key_values[layer_idx * 2 + 1]
+            cache.key_cache.append(key_states)
+            cache.value_cache.append(value_states)
+
+        return cache
+
+
+class RebelDynamicCache_4D(RebelDynamicCache):
+    def assign(
+        self,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        layer_idx: int,
+    ) -> None:
+        self.key_cache[layer_idx] = keys
+        self.value_cache[layer_idx] = values
+
+    def update(
+        self,
+        keys: torch.Tensor,
+        values: torch.Tensor,
+        layer_idx: int,
+        batch_idx: int,
+        read_first_step: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `keys` and `values` for the layer `layer_idx` and the batch 'batch_inx'
+        based on self.current_step,
+        """
+        current_step = self.current_steps[0 if read_first_step else batch_idx]
+        kend = current_step + keys.shape[-2]
+        vend = current_step + values.shape[-2]
+        update_keys = (
+            self.key_cache[layer_idx][batch_idx].unsqueeze(0).slice_scatter(keys, dim=-2, start=current_step, end=kend)
+        )
+        update_values = (
+            self.value_cache[layer_idx][batch_idx]
+            .unsqueeze(0)
+            .slice_scatter(values, dim=-2, start=current_step, end=vend)
+        )
+
+        return update_keys, update_values
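The `update` methods above write each new decoding step into a pre-allocated cache with `slice_scatter` rather than concatenation, at the offset taken from `position_ids`. A simplified, stand-alone illustration of that write pattern (shapes are illustrative, and the extra `unsqueeze` bookkeeping of the real class is omitted):

    # Simplified view of the slice_scatter update used by RebelDynamicCache.
    import torch

    heads, max_len, head_dim = 4, 128, 64
    cache = torch.zeros(1, heads, max_len, head_dim)  # pre-allocated K (or V) cache for one layer
    new_keys = torch.randn(1, heads, 1, head_dim)     # key states for a single decoding step
    current_step = 5                                  # would come from position_ids[:, 0]

    cache = cache.slice_scatter(new_keys, dim=-2, start=current_step, end=current_step + 1)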

optimum/rbln/transformers/generation/utils.py
@@ -32,7 +32,6 @@ class RBLNGenerationMixin:
         generation_config: Optional[GenerationConfig] = None,  # thkim change for 4.41.0
         **model_kwargs,
     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
         ###################### thkim change for 4.41.0 ############################
         if generation_config is not None:
             pad_token_id = generation_config.pad_token_id
@@ -216,7 +215,6 @@ class RBLNGenerationMixin:
         do_sample: Optional[bool] = True,
         **model_kwargs,
     ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
         ###################### thkim change for 4.41.0 ############################
         if generation_config is not None:
             pad_token_id = generation_config.pad_token_id

optimum/rbln/transformers/models/__init__.py
@@ -23,8 +23,10 @@
 
 from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
 from .dpt import RBLNDPTForDepthEstimation
+from .gemma import RBLNGemmaForCausalLM
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
 from .midm import RBLNMidmLMHeadModel
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
+from .xlm_roberta import RBLNXLMRobertaModel

optimum/rbln/transformers/models/bart/bart_architecture.py
@@ -56,7 +56,6 @@ class _BartAttention(BartAttention):
         cache_position: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
 
@@ -111,7 +110,6 @@ class _BartSdpaAttention(BartSdpaAttention):
         cache_position: torch.Tensor,
         key_value_states: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
 
@@ -166,7 +164,6 @@ class _BartDecoderLayer(BartDecoderLayer):
         cache_position: torch.Tensor,
         attn_impl: str = "eager",
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
-
         # Self Attention Block
         residual = hidden_states
         self_attn_past_key_value = past_key_value[:2]
@@ -218,7 +215,6 @@ class _BartDecoder(BartDecoder):
         cache_position: torch.Tensor,
         attn_impl: str = "eager",
     ):
-
         # embedding
         positions_idx = cache_position + self.embed_positions.offset
         positions = self.embed_positions.weight[positions_idx]
@@ -284,7 +280,6 @@ class BartDecoderWrapper(torch.nn.Module):
         self_kv_cache: torch.Tensor,
         cross_kv_cache: torch.Tensor,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
-
         # prepare past_key_values
         kv_cache = ()
         for i in range(0, self.num_layers * 2, 2):

optimum/rbln/transformers/models/decoderonly/__init__.py (new file)
@@ -0,0 +1,36 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyDecoderLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    DynamicNTKScalingRotaryEmbedding,
+    LinearScalingRotaryEmbedding,
+    RotaryEmbedding,
+    apply_rotary_pos_emb,
+    rotate_half,
+    slice_and_unsqueeze_cos_sin,
+)
+from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
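This new decoderonly package factors out the attention, rotary-embedding and wrapper code that the Llama port and the newly added Gemma port now build on (note llama_architecture.py shrinking by roughly 600 lines in the file list above). A minimal compile sketch for the new Gemma support, assuming RBLNGemmaForCausalLM is exported from the package root like the other causal-LM classes; the checkpoint id and `rbln_*` values are illustrative assumptions mirroring the existing causal-LM classes:

    # Hedged sketch; checkpoint id and rbln_* values are illustrative.
    from optimum.rbln import RBLNGemmaForCausalLM

    model = RBLNGemmaForCausalLM.from_pretrained(
        "google/gemma-2b",      # hypothetical checkpoint
        export=True,            # compile the decoder-only graph for RBLN devices
        rbln_batch_size=1,
        rbln_max_seq_len=2048,  # assumed knob, as used by the other causal-LM classes
    )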