optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +282 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
- optimum_rbln-0.1.8.dist-info/RECORD +73 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
- optimum_rbln-0.1.4.dist-info/RECORD +0 -63
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_seq2seq.py
CHANGED
@@ -23,13 +23,10 @@
|
|
23
23
|
|
24
24
|
import inspect
|
25
25
|
import logging
|
26
|
-
from pathlib import Path
|
27
|
-
from tempfile import TemporaryDirectory
|
28
26
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
29
27
|
|
30
28
|
import rebel
|
31
29
|
import torch
|
32
|
-
from optimum.exporters import TasksManager
|
33
30
|
from transformers import (
|
34
31
|
AutoModelForSeq2SeqLM,
|
35
32
|
BartConfig,
|
@@ -39,12 +36,11 @@ from transformers import (
|
|
39
36
|
)
|
40
37
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
41
38
|
|
42
|
-
from .modeling_base import
|
39
|
+
from .modeling_base import RBLNModel
|
43
40
|
from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
|
44
41
|
from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
|
45
42
|
from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
|
46
43
|
from .utils.runtime_utils import RBLNPytorchRuntime
|
47
|
-
from .utils.save_utils import maybe_save_preprocessors
|
48
44
|
|
49
45
|
|
50
46
|
logger = logging.getLogger(__name__)
|
@@ -75,7 +71,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
75
71
|
return Seq2SeqLMOutput(logits=outputs)
|
76
72
|
|
77
73
|
|
78
|
-
class RBLNModelForSeq2SeqLM(
|
74
|
+
class RBLNModelForSeq2SeqLM(RBLNModel):
|
79
75
|
"""
|
80
76
|
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
|
81
77
|
This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
@@ -88,7 +84,6 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
88
84
|
Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
|
89
85
|
"""
|
90
86
|
|
91
|
-
model_type = "rbln_model"
|
92
87
|
auto_model_class = AutoModelForSeq2SeqLM
|
93
88
|
|
94
89
|
def __post_init__(self, **kwargs):
|
@@ -97,8 +92,8 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
97
92
|
self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
|
98
93
|
self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
|
99
94
|
self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
|
100
|
-
self.encoder = RBLNRuntimeEncoder(runtime=self.
|
101
|
-
self.decoder = RBLNRuntimeDecoder(runtime=self.
|
95
|
+
self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
|
96
|
+
self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
|
102
97
|
|
103
98
|
def can_generate(self):
|
104
99
|
return True
|
@@ -149,74 +144,18 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
149
144
|
}
|
150
145
|
|
151
146
|
@classmethod
|
152
|
-
def
|
153
|
-
cls,
|
154
|
-
model_id: str,
|
155
|
-
config: "PretrainedConfig",
|
156
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
157
|
-
revision: Optional[str] = None,
|
158
|
-
force_download: bool = False,
|
159
|
-
cache_dir: Optional[str] = None,
|
160
|
-
subfolder: str = "",
|
161
|
-
local_files_only: bool = False,
|
162
|
-
trust_remote_code: bool = False,
|
163
|
-
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
164
|
-
**kwargs,
|
165
|
-
) -> "AutoModelForSeq2SeqLM":
|
166
|
-
"""
|
167
|
-
Exports a vanilla Transformers model into a rbln-compiled Module.
|
168
|
-
"""
|
169
|
-
task = kwargs.pop("task", None)
|
170
|
-
if task is None:
|
171
|
-
task = TasksManager.infer_task_from_model(cls.auto_model_class)
|
172
|
-
|
173
|
-
if model_save_dir is None:
|
174
|
-
save_dir = TemporaryDirectory()
|
175
|
-
save_dir_path = Path(save_dir.name)
|
176
|
-
else:
|
177
|
-
save_dir = model_save_dir
|
178
|
-
if isinstance(save_dir, TemporaryDirectory):
|
179
|
-
save_dir_path = Path(model_save_dir.name)
|
180
|
-
else:
|
181
|
-
save_dir_path = Path(model_save_dir)
|
182
|
-
save_dir_path.mkdir(exist_ok=True)
|
183
|
-
|
147
|
+
def update_kwargs(cls, kwargs):
|
184
148
|
kwargs.update(
|
185
149
|
{
|
186
150
|
"torchscript": True,
|
187
151
|
"return_dict": False,
|
188
|
-
"use_cache":
|
152
|
+
"use_cache": True,
|
189
153
|
}
|
190
154
|
)
|
155
|
+
return kwargs
|
191
156
|
|
192
|
-
|
193
|
-
|
194
|
-
model: AutoModelForSeq2SeqLM = TasksManager.get_model_from_task(
|
195
|
-
task=task,
|
196
|
-
model_name_or_path=model_id,
|
197
|
-
subfolder=subfolder,
|
198
|
-
revision=revision,
|
199
|
-
framework="pt",
|
200
|
-
cache_dir=cache_dir,
|
201
|
-
use_auth_token=use_auth_token,
|
202
|
-
local_files_only=local_files_only,
|
203
|
-
force_download=force_download,
|
204
|
-
trust_remote_code=trust_remote_code,
|
205
|
-
**kwargs,
|
206
|
-
)
|
207
|
-
|
208
|
-
if config is None:
|
209
|
-
config = model.config
|
210
|
-
|
211
|
-
config.save_pretrained(save_dir_path)
|
212
|
-
preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
|
213
|
-
|
214
|
-
# Get compilation arguments
|
215
|
-
if rbln_config_kwargs.get("rbln_config", None) is None:
|
216
|
-
rbln_config = cls.get_rbln_config(
|
217
|
-
preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
|
218
|
-
)
|
219
|
-
|
157
|
+
@classmethod
|
158
|
+
def get_compiled_model(cls, model, rbln_config: RBLNConfig):
|
220
159
|
def optimized_models(model):
|
221
160
|
if isinstance(model, T5ForConditionalGeneration):
|
222
161
|
encoder_model = T5EncoderWrapper(model).eval()
|
@@ -229,67 +168,54 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
229
168
|
|
230
169
|
return encoder_model, decoder_model
|
231
170
|
|
232
|
-
|
233
|
-
wrapped_encoder, wrapped_decoder = optimized_models(model)
|
171
|
+
wrapped_encoder, wrapped_decoder = optimized_models(model)
|
234
172
|
|
235
|
-
|
236
|
-
|
237
|
-
|
173
|
+
wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
|
174
|
+
wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
|
175
|
+
wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
238
176
|
|
239
|
-
|
240
|
-
|
241
|
-
|
177
|
+
wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
|
178
|
+
wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
|
179
|
+
wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
242
180
|
|
243
|
-
|
244
|
-
|
181
|
+
enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
|
182
|
+
dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
|
245
183
|
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
)
|
266
|
-
dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
267
|
-
|
268
|
-
connections = [
|
269
|
-
(enc_ir.outputs[0], dec_ir.inputs[5]),
|
270
|
-
(dec_ir.outputs[1], dec_ir.inputs[4]),
|
271
|
-
]
|
272
|
-
compiled_model = rebel.compile(
|
273
|
-
enc_ir,
|
274
|
-
dec_ir,
|
275
|
-
connections=connections,
|
276
|
-
fusion=enc_rbln_runtime_config.fusion,
|
277
|
-
npu=enc_rbln_runtime_config.npu,
|
278
|
-
tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
|
279
|
-
)
|
280
|
-
compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
|
281
|
-
|
282
|
-
compile()
|
283
|
-
|
284
|
-
rbln_config.save(save_dir_path)
|
285
|
-
|
286
|
-
return cls._from_pretrained(
|
287
|
-
model_id=save_dir_path,
|
288
|
-
config=config,
|
289
|
-
model_save_dir=save_dir,
|
290
|
-
**rbln_constructor_kwargs,
|
291
|
-
**kwargs,
|
184
|
+
if isinstance(model, T5ForConditionalGeneration):
|
185
|
+
enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
|
186
|
+
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
|
187
|
+
else:
|
188
|
+
enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
|
189
|
+
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
|
190
|
+
|
191
|
+
enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
|
192
|
+
dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
|
193
|
+
|
194
|
+
enc_ir = rebel.torchscript_to_ir(
|
195
|
+
enc_scripted_model,
|
196
|
+
input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
|
197
|
+
name=enc_rbln_runtime_config.rbln_mod_name,
|
198
|
+
)
|
199
|
+
dec_ir = rebel.torchscript_to_ir(
|
200
|
+
dec_scripted_model,
|
201
|
+
input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
|
202
|
+
name=dec_rbln_runtime_config.rbln_mod_name,
|
292
203
|
)
|
204
|
+
dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
205
|
+
|
206
|
+
connections = [
|
207
|
+
(enc_ir.outputs[0], dec_ir.inputs[5]),
|
208
|
+
(dec_ir.outputs[1], dec_ir.inputs[4]),
|
209
|
+
]
|
210
|
+
compiled_model = rebel.compile(
|
211
|
+
enc_ir,
|
212
|
+
dec_ir,
|
213
|
+
connections=connections,
|
214
|
+
fusion=enc_rbln_runtime_config.fusion,
|
215
|
+
npu=enc_rbln_runtime_config.npu,
|
216
|
+
tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
|
217
|
+
)
|
218
|
+
return compiled_model
|
293
219
|
|
294
220
|
@classmethod
|
295
221
|
def _get_rbln_config(
|
@@ -411,11 +337,14 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
411
337
|
|
412
338
|
return rbln_config
|
413
339
|
|
414
|
-
|
340
|
+
@classmethod
|
341
|
+
def _create_runtimes(
|
342
|
+
cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
|
343
|
+
) -> List[rebel.Runtime]:
|
415
344
|
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
416
345
|
return [
|
417
|
-
|
418
|
-
|
346
|
+
compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
|
347
|
+
compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
|
419
348
|
]
|
420
349
|
|
421
350
|
def forward(
|
@@ -436,9 +365,6 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
436
365
|
|
437
366
|
return Seq2SeqLMOutput(logits=lm_logits)
|
438
367
|
|
439
|
-
def __repr__(self):
|
440
|
-
return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
|
441
|
-
|
442
368
|
def _prepare_encoder_decoder_kwargs_for_generation(
|
443
369
|
self,
|
444
370
|
inputs_tensor: torch.Tensor,
|
@@ -27,28 +27,36 @@ from transformers.utils import _LazyModule
|
|
27
27
|
|
28
28
|
|
29
29
|
_import_structure = {
|
30
|
+
"cache_utils": ["RebelDynamicCache"],
|
30
31
|
"generation": ["BatchTextIteratorStreamer"],
|
31
32
|
"models": [
|
32
33
|
"RBLNCLIPTextModel",
|
33
34
|
"RBLNCLIPTextModelWithProjection",
|
35
|
+
"RBLNDPTForDepthEstimation",
|
36
|
+
"RBLNGemmaForCausalLM",
|
34
37
|
"RBLNGPT2LMHeadModel",
|
35
38
|
"RBLNWav2Vec2ForCTC",
|
36
39
|
"RBLNWhisperForConditionalGeneration",
|
37
40
|
"RBLNLlamaForCausalLM",
|
38
41
|
"RBLNMidmLMHeadModel",
|
42
|
+
"RBLNXLMRobertaModel"
|
39
43
|
],
|
40
44
|
}
|
41
45
|
|
42
46
|
if TYPE_CHECKING:
|
47
|
+
from .cache_utils import RebelDynamicCache
|
43
48
|
from .generation import BatchTextIteratorStreamer
|
44
49
|
from .models import (
|
45
50
|
RBLNCLIPTextModel,
|
46
51
|
RBLNCLIPTextModelWithProjection,
|
52
|
+
RBLNDPTForDepthEstimation,
|
53
|
+
RBLNGemmaForCausalLM,
|
47
54
|
RBLNGPT2LMHeadModel,
|
48
55
|
RBLNLlamaForCausalLM,
|
49
56
|
RBLNMidmLMHeadModel,
|
50
57
|
RBLNWav2Vec2ForCTC,
|
51
58
|
RBLNWhisperForConditionalGeneration,
|
59
|
+
RBLNXLMRobertaModel,
|
52
60
|
)
|
53
61
|
else:
|
54
62
|
import sys
|
@@ -0,0 +1,111 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from transformers.cache_utils import DynamicCache
|
5
|
+
|
6
|
+
|
7
|
+
class RebelDynamicCache(DynamicCache):
|
8
|
+
"""
|
9
|
+
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
10
|
+
|
11
|
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
12
|
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
13
|
+
"""
|
14
|
+
|
15
|
+
def __init__(self, current_steps) -> None:
|
16
|
+
super().__init__()
|
17
|
+
self.current_steps = current_steps
|
18
|
+
|
19
|
+
def assign(
|
20
|
+
self,
|
21
|
+
key_states: torch.Tensor,
|
22
|
+
value_states: torch.Tensor,
|
23
|
+
layer_idx: int,
|
24
|
+
) -> None:
|
25
|
+
self.key_cache[layer_idx] = key_states.squeeze(2)
|
26
|
+
self.value_cache[layer_idx] = value_states.squeeze(2)
|
27
|
+
|
28
|
+
def update(
|
29
|
+
self,
|
30
|
+
key_states: torch.Tensor,
|
31
|
+
value_states: torch.Tensor,
|
32
|
+
layer_idx: int,
|
33
|
+
batch_idx: int,
|
34
|
+
read_first_step: Optional[bool] = False,
|
35
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
36
|
+
"""
|
37
|
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx` and the batch 'batch_inx'
|
38
|
+
based on self.current_step,
|
39
|
+
"""
|
40
|
+
current_step = self.current_steps[0 if read_first_step else batch_idx]
|
41
|
+
kend = current_step + key_states.shape[-2]
|
42
|
+
vend = current_step + value_states.shape[-2]
|
43
|
+
update_key_states = (
|
44
|
+
self.key_cache[layer_idx][batch_idx]
|
45
|
+
.unsqueeze(0)
|
46
|
+
.unsqueeze(2)
|
47
|
+
.slice_scatter(key_states, dim=-2, start=current_step, end=kend)
|
48
|
+
)
|
49
|
+
update_value_states = (
|
50
|
+
self.value_cache[layer_idx][batch_idx]
|
51
|
+
.unsqueeze(0)
|
52
|
+
.unsqueeze(2)
|
53
|
+
.slice_scatter(value_states, dim=-2, start=current_step, end=vend)
|
54
|
+
)
|
55
|
+
|
56
|
+
return update_key_states, update_value_states
|
57
|
+
|
58
|
+
@classmethod
|
59
|
+
def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "DynamicCache":
|
60
|
+
"""Converts a cache in the rbln cache format (list of past_kv) into an equivalent `DynamicCache`."""
|
61
|
+
|
62
|
+
batch, _ = position_ids.shape
|
63
|
+
current_steps = [position_ids[b][0] for b in range(batch)]
|
64
|
+
|
65
|
+
assert len(current_steps) == batch
|
66
|
+
cache = cls(current_steps)
|
67
|
+
|
68
|
+
for layer_idx in range(num_hidden_layer):
|
69
|
+
key_states = past_key_values[layer_idx * 2]
|
70
|
+
value_states = past_key_values[layer_idx * 2 + 1]
|
71
|
+
cache.key_cache.append(key_states)
|
72
|
+
cache.value_cache.append(value_states)
|
73
|
+
|
74
|
+
return cache
|
75
|
+
|
76
|
+
|
77
|
+
class RebelDynamicCache_4D(RebelDynamicCache):
|
78
|
+
def assign(
|
79
|
+
self,
|
80
|
+
keys: torch.Tensor,
|
81
|
+
values: torch.Tensor,
|
82
|
+
layer_idx: int,
|
83
|
+
) -> None:
|
84
|
+
self.key_cache[layer_idx] = keys
|
85
|
+
self.value_cache[layer_idx] = values
|
86
|
+
|
87
|
+
def update(
|
88
|
+
self,
|
89
|
+
keys: torch.Tensor,
|
90
|
+
values: torch.Tensor,
|
91
|
+
layer_idx: int,
|
92
|
+
batch_idx: int,
|
93
|
+
read_first_step: Optional[bool] = False,
|
94
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
95
|
+
"""
|
96
|
+
Updates the cache with the new `keys` and `values` for the layer `layer_idx` and the batch 'batch_inx'
|
97
|
+
based on self.current_step,
|
98
|
+
"""
|
99
|
+
current_step = self.current_steps[0 if read_first_step else batch_idx]
|
100
|
+
kend = current_step + keys.shape[-2]
|
101
|
+
vend = current_step + values.shape[-2]
|
102
|
+
update_keys = (
|
103
|
+
self.key_cache[layer_idx][batch_idx].unsqueeze(0).slice_scatter(keys, dim=-2, start=current_step, end=kend)
|
104
|
+
)
|
105
|
+
update_values = (
|
106
|
+
self.value_cache[layer_idx][batch_idx]
|
107
|
+
.unsqueeze(0)
|
108
|
+
.slice_scatter(values, dim=-2, start=current_step, end=vend)
|
109
|
+
)
|
110
|
+
|
111
|
+
return update_keys, update_values
|
@@ -32,7 +32,6 @@ class RBLNGenerationMixin:
|
|
32
32
|
generation_config: Optional[GenerationConfig] = None, # thkim change for 4.41.0
|
33
33
|
**model_kwargs,
|
34
34
|
) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
|
35
|
-
|
36
35
|
###################### thkim change for 4.41.0 ############################
|
37
36
|
if generation_config is not None:
|
38
37
|
pad_token_id = generation_config.pad_token_id
|
@@ -216,7 +215,6 @@ class RBLNGenerationMixin:
|
|
216
215
|
do_sample: Optional[bool] = True,
|
217
216
|
**model_kwargs,
|
218
217
|
) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
|
219
|
-
|
220
218
|
###################### thkim change for 4.41.0 ############################
|
221
219
|
if generation_config is not None:
|
222
220
|
pad_token_id = generation_config.pad_token_id
|
@@ -22,8 +22,11 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
25
|
+
from .dpt import RBLNDPTForDepthEstimation
|
26
|
+
from .gemma import RBLNGemmaForCausalLM
|
25
27
|
from .gpt2 import RBLNGPT2LMHeadModel
|
26
28
|
from .llama import RBLNLlamaForCausalLM
|
27
29
|
from .midm import RBLNMidmLMHeadModel
|
28
30
|
from .wav2vec2 import RBLNWav2Vec2ForCTC
|
29
31
|
from .whisper import RBLNWhisperForConditionalGeneration
|
32
|
+
from .xlm_roberta import RBLNXLMRobertaModel
|
@@ -56,7 +56,6 @@ class _BartAttention(BartAttention):
|
|
56
56
|
cache_position: torch.Tensor,
|
57
57
|
key_value_states: Optional[torch.Tensor] = None,
|
58
58
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
59
|
-
|
60
59
|
bsz, tgt_len, _ = hidden_states.size()
|
61
60
|
is_cross_attention = key_value_states is not None
|
62
61
|
|
@@ -111,7 +110,6 @@ class _BartSdpaAttention(BartSdpaAttention):
|
|
111
110
|
cache_position: torch.Tensor,
|
112
111
|
key_value_states: Optional[torch.Tensor] = None,
|
113
112
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
114
|
-
|
115
113
|
bsz, tgt_len, _ = hidden_states.size()
|
116
114
|
is_cross_attention = key_value_states is not None
|
117
115
|
|
@@ -166,7 +164,6 @@ class _BartDecoderLayer(BartDecoderLayer):
|
|
166
164
|
cache_position: torch.Tensor,
|
167
165
|
attn_impl: str = "eager",
|
168
166
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
|
169
|
-
|
170
167
|
# Self Attention Block
|
171
168
|
residual = hidden_states
|
172
169
|
self_attn_past_key_value = past_key_value[:2]
|
@@ -218,7 +215,6 @@ class _BartDecoder(BartDecoder):
|
|
218
215
|
cache_position: torch.Tensor,
|
219
216
|
attn_impl: str = "eager",
|
220
217
|
):
|
221
|
-
|
222
218
|
# embedding
|
223
219
|
positions_idx = cache_position + self.embed_positions.offset
|
224
220
|
positions = self.embed_positions.weight[positions_idx]
|
@@ -284,7 +280,6 @@ class BartDecoderWrapper(torch.nn.Module):
|
|
284
280
|
self_kv_cache: torch.Tensor,
|
285
281
|
cross_kv_cache: torch.Tensor,
|
286
282
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
|
287
|
-
|
288
283
|
# prepare past_key_values
|
289
284
|
kv_cache = ()
|
290
285
|
for i in range(0, self.num_layers * 2, 2):
|
@@ -0,0 +1,36 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .decoderonly_architecture import (
|
25
|
+
DecoderOnlyAttention,
|
26
|
+
DecoderOnlyDecoderLayer,
|
27
|
+
DecoderOnlyModel,
|
28
|
+
DecoderOnlyWrapper,
|
29
|
+
DynamicNTKScalingRotaryEmbedding,
|
30
|
+
LinearScalingRotaryEmbedding,
|
31
|
+
RotaryEmbedding,
|
32
|
+
apply_rotary_pos_emb,
|
33
|
+
rotate_half,
|
34
|
+
slice_and_unsqueeze_cos_sin,
|
35
|
+
)
|
36
|
+
from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
|