optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +17 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +7 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +19 -1
- optimum/rbln/modeling_base.py +162 -18
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
@@ -23,17 +23,10 @@
|
|
23
23
|
|
24
24
|
import inspect
|
25
25
|
import logging
|
26
|
-
from typing import TYPE_CHECKING, Any, Callable
|
26
|
+
from typing import TYPE_CHECKING, Any, Callable
|
27
27
|
|
28
|
-
import
|
29
|
-
import
|
30
|
-
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
31
|
-
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
32
|
-
|
33
|
-
from ....modeling_base import RBLNModel
|
34
|
-
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
|
35
|
-
from ....utils.runtime_utils import RBLNPytorchRuntime
|
36
|
-
from ...generation.utils import RBLNGenerationMixin
|
28
|
+
from ....modeling_config import RBLNConfig
|
29
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
37
30
|
from .hf_hub_cached.modeling_midm import MidmLMHeadModel
|
38
31
|
from .midm_architecture import (
|
39
32
|
MidmLMHeadModelWrapper,
|
@@ -41,41 +34,18 @@ from .midm_architecture import (
|
|
41
34
|
|
42
35
|
|
43
36
|
logger = logging.getLogger(__name__)
|
44
|
-
|
45
37
|
if TYPE_CHECKING:
|
46
38
|
from transformers import (
|
47
|
-
|
48
|
-
AutoProcessor,
|
49
|
-
AutoTokenizer,
|
50
|
-
PretrainedConfig,
|
39
|
+
PreTrainedModel,
|
51
40
|
)
|
52
41
|
|
53
42
|
|
54
|
-
class
|
55
|
-
mandatory_members = ["main_input_name"]
|
56
|
-
|
57
|
-
# RBLN_Runtimemodule
|
58
|
-
def forward(
|
59
|
-
self,
|
60
|
-
input_ids: torch.LongTensor = None,
|
61
|
-
attention_mask: torch.LongTensor = None,
|
62
|
-
cache_position: torch.Tensor = None,
|
63
|
-
**kwargs: Dict[str, Any],
|
64
|
-
):
|
65
|
-
logits = super().forward(
|
66
|
-
input_ids=input_ids,
|
67
|
-
attention_mask=attention_mask,
|
68
|
-
cache_position=cache_position,
|
69
|
-
)
|
70
|
-
return logits
|
71
|
-
|
72
|
-
|
73
|
-
class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
43
|
+
class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
|
74
44
|
"""
|
75
45
|
The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
76
46
|
embeddings).
|
77
47
|
|
78
|
-
This model inherits from [`
|
48
|
+
This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
|
79
49
|
library implements for all its models.
|
80
50
|
|
81
51
|
It implements the methods to convert a pre-trained transformers Midm model into a RBLN transformer model by:
|
@@ -84,46 +54,10 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
84
54
|
|
85
55
|
"""
|
86
56
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
def __init__(
|
92
|
-
self,
|
93
|
-
models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
|
94
|
-
config: PretrainedConfig = None,
|
95
|
-
preprocessors: Optional[List] = None,
|
96
|
-
rbln_config: Optional[RBLNConfig] = None,
|
97
|
-
rbln_device: Optional[List[int]] = None,
|
98
|
-
rbln_device_map: Optional[Dict[str, int]] = None,
|
99
|
-
**kwargs,
|
100
|
-
):
|
101
|
-
super().__init__(
|
102
|
-
models,
|
103
|
-
config,
|
104
|
-
preprocessors,
|
105
|
-
rbln_config,
|
106
|
-
rbln_device=rbln_device,
|
107
|
-
rbln_device_map=rbln_device_map,
|
108
|
-
**kwargs,
|
109
|
-
)
|
110
|
-
self.batch_size = self.rbln_config.meta["rbln_batch_size"]
|
111
|
-
self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
|
112
|
-
self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
|
113
|
-
|
114
|
-
self.prefill_attention_mask = torch.zeros(
|
115
|
-
self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
|
116
|
-
)
|
117
|
-
self.causal_mask = 1 - torch.triu(
|
118
|
-
torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
|
119
|
-
)
|
120
|
-
|
121
|
-
self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0], main_input_name="input_ids")
|
122
|
-
self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
|
123
|
-
self.past_cached_length = 0
|
124
|
-
|
125
|
-
def can_generate(self):
|
126
|
-
return True
|
57
|
+
@classmethod
|
58
|
+
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
|
59
|
+
rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
|
60
|
+
return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
|
127
61
|
|
128
62
|
def __getattr__(self, __name: str) -> Any:
|
129
63
|
"""This is the key method to implement RBLN-Midm.
|
@@ -139,252 +73,3 @@ class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
|
|
139
73
|
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
140
74
|
return redirect(val)
|
141
75
|
return val
|
142
|
-
|
143
|
-
def _reorder_cache(self, past_key_values, beam_idx):
|
144
|
-
# TODO(jongho): implement
|
145
|
-
raise NotImplementedError
|
146
|
-
|
147
|
-
@classmethod
|
148
|
-
@torch.inference_mode()
|
149
|
-
def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
|
150
|
-
wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
|
151
|
-
prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
|
152
|
-
dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
|
153
|
-
|
154
|
-
prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
|
155
|
-
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
|
156
|
-
|
157
|
-
prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
|
158
|
-
dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
|
159
|
-
|
160
|
-
prefill_ir = rebel.torchscript_to_ir(
|
161
|
-
prefill_scripted_model,
|
162
|
-
input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
|
163
|
-
)
|
164
|
-
dec_ir = rebel.torchscript_to_ir(
|
165
|
-
dec_scripted_model,
|
166
|
-
input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
|
167
|
-
)
|
168
|
-
|
169
|
-
connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
|
170
|
-
|
171
|
-
compiled_model = rebel.compile(
|
172
|
-
prefill_ir,
|
173
|
-
dec_ir,
|
174
|
-
connections=connections,
|
175
|
-
fusion=prefill_rbln_runtime_config.fusion,
|
176
|
-
npu=prefill_rbln_runtime_config.npu,
|
177
|
-
tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
|
178
|
-
use_weight_sharing=True,
|
179
|
-
)
|
180
|
-
return compiled_model
|
181
|
-
|
182
|
-
@classmethod
|
183
|
-
def update_kwargs(cls, kwargs):
|
184
|
-
"""
|
185
|
-
Update user-given kwargs to get proper pytorch model.
|
186
|
-
|
187
|
-
For example, `torchscript`=True should be set because torch.jit
|
188
|
-
does not support `transformers` output instances as module output;
|
189
|
-
"""
|
190
|
-
kwargs.update(
|
191
|
-
{
|
192
|
-
"torchscript": True,
|
193
|
-
"return_dict": False,
|
194
|
-
"use_cache": True,
|
195
|
-
"torch_dtype": torch.float32,
|
196
|
-
"_attn_implementation": "eager",
|
197
|
-
}
|
198
|
-
)
|
199
|
-
return kwargs
|
200
|
-
|
201
|
-
@classmethod
|
202
|
-
def get_pytorch_model(
|
203
|
-
cls,
|
204
|
-
model_id: str,
|
205
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
206
|
-
revision: Optional[str] = None,
|
207
|
-
force_download: bool = False,
|
208
|
-
cache_dir: Optional[str] = None,
|
209
|
-
subfolder: str = "",
|
210
|
-
local_files_only: bool = False,
|
211
|
-
trust_remote_code: bool = False,
|
212
|
-
rbln_config_kwargs: Optional[Dict[str, Any]] = None,
|
213
|
-
rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
|
214
|
-
**kwargs,
|
215
|
-
) -> PreTrainedModel:
|
216
|
-
if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
|
217
|
-
config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
|
218
|
-
if hf_position_embedding := getattr(config, "max_position_embeddings", None):
|
219
|
-
if hf_position_embedding < rbln_max_seq_len:
|
220
|
-
logger.warning(
|
221
|
-
f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
|
222
|
-
"This may lead to incorrect inferences of the model."
|
223
|
-
)
|
224
|
-
kwargs.update({"max_position_embeddings": rbln_max_seq_len})
|
225
|
-
|
226
|
-
return super().get_pytorch_model(
|
227
|
-
model_id=model_id,
|
228
|
-
use_auth_token=use_auth_token,
|
229
|
-
revision=revision,
|
230
|
-
force_download=force_download,
|
231
|
-
cache_dir=cache_dir,
|
232
|
-
subfolder=subfolder,
|
233
|
-
local_files_only=local_files_only,
|
234
|
-
trust_remote_code=trust_remote_code,
|
235
|
-
rbln_config_kwargs=rbln_config_kwargs,
|
236
|
-
rbln_constructor_kwargs=rbln_constructor_kwargs,
|
237
|
-
ignore_mismatched_sizes=True,
|
238
|
-
**kwargs,
|
239
|
-
)
|
240
|
-
|
241
|
-
@classmethod
|
242
|
-
def _get_rbln_config(
|
243
|
-
cls,
|
244
|
-
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
245
|
-
model_config: "PretrainedConfig",
|
246
|
-
rbln_prefill_chunk_size: Optional[int] = 128,
|
247
|
-
rbln_max_seq_len: Optional[int] = None,
|
248
|
-
rbln_batch_size: Optional[int] = None,
|
249
|
-
) -> RBLNConfig:
|
250
|
-
meta = {}
|
251
|
-
if rbln_max_seq_len is None:
|
252
|
-
rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
|
253
|
-
|
254
|
-
if rbln_max_seq_len is None:
|
255
|
-
for tokenizer in preprocessors:
|
256
|
-
if hasattr(tokenizer, "model_max_length"):
|
257
|
-
rbln_max_seq_len = tokenizer.model_max_length
|
258
|
-
break
|
259
|
-
if rbln_max_seq_len is None:
|
260
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
261
|
-
|
262
|
-
if rbln_batch_size is None:
|
263
|
-
rbln_batch_size = 1
|
264
|
-
|
265
|
-
meta["rbln_prefill_chunk_size"] = rbln_prefill_chunk_size
|
266
|
-
meta["rbln_max_seq_len"] = rbln_max_seq_len
|
267
|
-
meta["rbln_batch_size"] = rbln_batch_size if rbln_batch_size is not None else 1
|
268
|
-
|
269
|
-
def get_input_info(query_length):
|
270
|
-
input_info = [
|
271
|
-
("input_ids", [rbln_batch_size, query_length], "int64"),
|
272
|
-
("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
|
273
|
-
(
|
274
|
-
"cache_position",
|
275
|
-
[],
|
276
|
-
"int32",
|
277
|
-
),
|
278
|
-
]
|
279
|
-
input_info.extend(
|
280
|
-
[
|
281
|
-
(
|
282
|
-
f"past_key_values_{i}",
|
283
|
-
[
|
284
|
-
rbln_batch_size,
|
285
|
-
model_config.n_head,
|
286
|
-
rbln_max_seq_len,
|
287
|
-
model_config.hidden_size // model_config.n_head,
|
288
|
-
],
|
289
|
-
"float32",
|
290
|
-
)
|
291
|
-
for i in range(model_config.n_layer * 2)
|
292
|
-
]
|
293
|
-
)
|
294
|
-
return input_info
|
295
|
-
|
296
|
-
# model input info
|
297
|
-
prefill_input_info = get_input_info(query_length=rbln_prefill_chunk_size)
|
298
|
-
dec_input_info = get_input_info(query_length=1)
|
299
|
-
|
300
|
-
prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
|
301
|
-
dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
|
302
|
-
|
303
|
-
dec_rbln_runtime_config.batch_size = rbln_batch_size
|
304
|
-
|
305
|
-
rbln_config = RBLNConfig.from_rbln_runtime_configs(
|
306
|
-
[prefill_rbln_runtime_config, dec_rbln_runtime_config],
|
307
|
-
_rbln_meta=meta,
|
308
|
-
)
|
309
|
-
|
310
|
-
return rbln_config
|
311
|
-
|
312
|
-
@classmethod
|
313
|
-
def _create_runtimes(
|
314
|
-
cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
|
315
|
-
) -> List[rebel.Runtime]:
|
316
|
-
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
317
|
-
return [
|
318
|
-
compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
|
319
|
-
compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
|
320
|
-
]
|
321
|
-
|
322
|
-
def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
|
323
|
-
batch_size, cur_len = input_ids.shape
|
324
|
-
past_cached_length = past_key_values
|
325
|
-
|
326
|
-
if past_cached_length == 0:
|
327
|
-
mod_len = cur_len % self.prefill_chunk_size
|
328
|
-
self.pad_len = self.prefill_chunk_size - mod_len if mod_len > 0 else 0
|
329
|
-
|
330
|
-
prompt_attn_mask = torch.nn.functional.pad(attention_mask, (self.pad_len, 0), value=0)
|
331
|
-
self.prompt_attn_mask = prompt_attn_mask.reshape(batch_size, 1, 1, -1).contiguous()
|
332
|
-
|
333
|
-
input_ids = torch.nn.functional.pad(input_ids, (self.pad_len, 0), value=0)
|
334
|
-
attention_mask = self.prefill_attention_mask.clone()
|
335
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
336
|
-
|
337
|
-
query_length = cur_len + self.pad_len
|
338
|
-
else:
|
339
|
-
attention_mask = torch.nn.functional.pad(
|
340
|
-
attention_mask, (self.pad_len, self.max_seq_len - cur_len - self.pad_len)
|
341
|
-
)
|
342
|
-
attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
|
343
|
-
cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
|
344
|
-
input_ids = input_ids[:, -1:].contiguous()
|
345
|
-
query_length = 1
|
346
|
-
|
347
|
-
model_inputs = {
|
348
|
-
"input_ids": input_ids,
|
349
|
-
"past_key_values": past_cached_length,
|
350
|
-
"attention_mask": attention_mask,
|
351
|
-
"cache_position": cache_position,
|
352
|
-
"query_length": query_length,
|
353
|
-
}
|
354
|
-
|
355
|
-
return model_inputs
|
356
|
-
|
357
|
-
def forward(
|
358
|
-
self,
|
359
|
-
input_ids: Optional[torch.LongTensor] = None,
|
360
|
-
past_key_values: int = None,
|
361
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
362
|
-
cache_position: Optional[torch.Tensor] = None,
|
363
|
-
query_length: Optional[torch.Tensor] = None,
|
364
|
-
**kwargs,
|
365
|
-
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
366
|
-
past_cached_length = past_key_values
|
367
|
-
|
368
|
-
if past_cached_length is not None:
|
369
|
-
past_cached_length += query_length
|
370
|
-
|
371
|
-
if cache_position == 0:
|
372
|
-
for step in range(0, query_length, self.prefill_chunk_size):
|
373
|
-
sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
|
374
|
-
attention_mask[:, :, :, :step] = 1
|
375
|
-
attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
|
376
|
-
attention_mask[:, :, :, :query_length] *= self.prompt_attn_mask
|
377
|
-
|
378
|
-
output = self.prefill_decoder(
|
379
|
-
input_ids=sliced_input_ids.contiguous(),
|
380
|
-
attention_mask=attention_mask,
|
381
|
-
cache_position=cache_position + step,
|
382
|
-
)
|
383
|
-
cache_position += self.prefill_chunk_size
|
384
|
-
else:
|
385
|
-
output = self.decoder(
|
386
|
-
input_ids=input_ids.contiguous(),
|
387
|
-
attention_mask=attention_mask,
|
388
|
-
cache_position=cache_position,
|
389
|
-
)
|
390
|
-
return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_mistral import RBLNMistralForCausalLM
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
|
25
|
+
from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
|
26
|
+
|
27
|
+
|
28
|
+
class MistralForCausalLMWrapper(DecoderOnlyWrapper):
    """Mistral flavour of `DecoderOnlyWrapper`; it adds nothing of its own on top of the base wrapper."""
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import inspect
|
25
|
+
import logging
|
26
|
+
from typing import TYPE_CHECKING, Any, Callable
|
27
|
+
|
28
|
+
from transformers import MistralForCausalLM
|
29
|
+
|
30
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
31
|
+
from .mistral_architecture import MistralForCausalLMWrapper
|
32
|
+
|
33
|
+
|
34
|
+
if TYPE_CHECKING:
|
35
|
+
from transformers import PreTrainedModel
|
36
|
+
|
37
|
+
from ....modeling_config import RBLNConfig
|
38
|
+
|
39
|
+
|
40
|
+
logger = logging.getLogger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The Mistral Model transformer with a language modeling head (linear layer) on top.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based MistralForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers MistralForCausalLM model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.
    """

    @classmethod
    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
        """Wrap `model` in a `MistralForCausalLMWrapper` and call `.eval()` on the wrapper.

        Args:
            model: The transformers model handed to the wrapper.
            rbln_config: RBLN configuration; its `meta["rbln_max_seq_len"]` entry is
                forwarded to the wrapper as the second positional argument.

        Returns:
            Whatever `.eval()` returns for the freshly built wrapper.
        """
        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()

    def __getattr__(self, __name: str) -> Any:
        """Fall back to `MistralForCausalLM` for attributes this object does not define.

        Python only calls `__getattr__` after regular attribute lookup has failed.
        The attribute is then read from the `MistralForCausalLM` class itself (not
        from an instance); if it is missing there too, the `getattr` call below
        raises `AttributeError`.

        Args:
            __name: Name of the attribute being looked up.

        Returns:
            A callable bound to this instance when the attribute is a callable that
            declares a `self` parameter, otherwise the attribute value unchanged.
        """

        def redirect(func):
            # Bind this RBLN instance as the first argument of the borrowed function.
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(MistralForCausalLM, __name)

        # Only plain instance methods (those declaring `self`) need rebinding;
        # everything else is returned as found on the class.
        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)

        return val
|
@@ -70,7 +70,7 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
|
|
70
70
|
auto_model_class = AutoModelForMaskedLM
|
71
71
|
|
72
72
|
@classmethod
|
73
|
-
def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
|
73
|
+
def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
|
74
74
|
return _Wav2Vec2(model).eval()
|
75
75
|
|
76
76
|
@classmethod
|
@@ -57,7 +57,6 @@ class _WhisperAttention(WhisperAttention):
|
|
57
57
|
cache_position: Optional[torch.Tensor] = None,
|
58
58
|
**kwargs,
|
59
59
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
60
|
-
|
61
60
|
bsz, tgt_len, _ = hidden_states.size()
|
62
61
|
is_cross_attention = key_value_states is not None
|
63
62
|
|
@@ -123,7 +122,6 @@ class _WhisperSdpaAttention(WhisperSdpaAttention):
|
|
123
122
|
cache_position: Optional[torch.Tensor] = None,
|
124
123
|
**kwargs,
|
125
124
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
126
|
-
|
127
125
|
bsz, tgt_len, _ = hidden_states.size()
|
128
126
|
|
129
127
|
is_cross_attention = key_value_states is not None
|
@@ -189,7 +187,6 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
|
|
189
187
|
cache_position: Optional[torch.Tensor] = None,
|
190
188
|
attn_impl: str = "eager",
|
191
189
|
) -> torch.Tensor:
|
192
|
-
|
193
190
|
# Self Attention Block
|
194
191
|
residual = hidden_states
|
195
192
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
@@ -248,7 +245,6 @@ class _WhisperDecoder(WhisperDecoder):
|
|
248
245
|
attn_impl: str = "eager",
|
249
246
|
**kwargs,
|
250
247
|
):
|
251
|
-
|
252
248
|
input_shape = input_ids.size()
|
253
249
|
input_ids = input_ids.view(-1, input_shape[-1])
|
254
250
|
|
@@ -312,7 +308,6 @@ class _WhisperDecoderWrapper(torch.nn.Module):
|
|
312
308
|
self_kv_cache: torch.Tensor,
|
313
309
|
cross_kv_cache: torch.Tensor,
|
314
310
|
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
315
|
-
|
316
311
|
# prepare past_key_values
|
317
312
|
kv_cache = ()
|
318
313
|
for i in range(0, self.num_layers * 2, 2):
|
@@ -367,7 +362,6 @@ class _WhisperEncoderWrapper(torch.nn.Module):
|
|
367
362
|
self,
|
368
363
|
input_features: Optional[torch.LongTensor] = None,
|
369
364
|
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
370
|
-
|
371
365
|
encoder_outputs = self.encoder(input_features=input_features)
|
372
366
|
last_hidden_states = encoder_outputs[0]
|
373
367
|
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_xlm_roberta import RBLNXLMRobertaModel
|