optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +22 -12
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
- optimum/rbln/diffusers/pipelines/__init__.py +22 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +164 -758
- optimum/rbln/modeling_diffusers.py +51 -122
- optimum/rbln/transformers/__init__.py +0 -2
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
- optimum/rbln/utils/decorator_utils.py +10 -6
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +15 -1
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +1 -1
- optimum/rbln/utils/submodule.py +114 -0
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -20,27 +20,23 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-import functools
-import glob
 import inspect
-import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
-import transformers
-from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.timer_utils import rbln_timer
+from ...utils.rbln_quantization import QuantizationManager
 from .decoderonly_architecture import DecoderOnlyWrapper
 
 
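Judging by its imports, the hunk above is `modeling_decoderonly.py`: it swaps the truncated `from ....` import for the new top-level `optimum/rbln/modeling.py` module introduced in this release (see `modeling.py +572` and `modeling_base.py -758` in the file list). A minimal sketch of the same import in absolute form, assuming the module is importable at that path:

```python
# Hedged sketch: absolute form of the relative import added above.
# Assumes optimum/rbln/modeling.py (new in 0.1.15 per the file list) now exposes RBLNModel,
# most of which previously lived in modeling_base.py.
from optimum.rbln.modeling import RBLNModel
```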
@@ -54,12 +50,6 @@ if TYPE_CHECKING:
         PretrainedConfig,
     )
 
-SUPPORTED_QUANTIZATIONS = {
-    "rbln": [
-        "w4a16",
-    ],
-}
-
 
 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name", "embed_tokens"]
@@ -127,24 +117,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
     _decoder_wrapper_cls = DecoderOnlyWrapper
-
-
-    @classmethod
-    @property
-    def original_cls(cls):
-        """
-        Lazily loads and caches the corresponding Hugging Face model class.
-        Removes 'RBLN' prefix from the class name to get the original class name
-        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
-        the transformers module.
-
-        Returns:
-            type: The original Hugging Face model class
-        """
-        if cls._original_cls is None:
-            hf_original_cls_name = cls.__name__[4:]
-            cls._original_cls = getattr(transformers, hf_original_cls_name)
-        return cls._original_cls
+    _use_rotary_emb = True
 
     def __post_init__(self, **kwargs):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -203,6 +176,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -212,57 +186,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import
+        from ...utils.rbln_quantization import prepare_model_for_quantization
 
         kwargs = cls.update_kwargs(kwargs)
 
-        config
-        [… 8 removed lines not captured in this view …]
+        if config is None:
+            config = AutoConfig.from_pretrained(
+                model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )
 
         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)
 
-
-
-        n_layer = kwargs.get("num_hidden_layers", None)
-        cls._load_weights_directly_to_model(model, model_id, n_layer)
+        prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))
 
         return model
 
-    def _load_weights_directly_to_model(model, model_id, n_layer=None):
-        """
-        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
-        """
-
-        model_params = dict(model.named_parameters(recurse=True))
-        model_buffers = dict(model.named_buffers(recurse=True))
-        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
-
-        target_layers = list(range(n_layer)) if n_layer is not None else None
-
-        for safetensor_file in safetensor_files:
-            file_data = load_file(safetensor_file)
-            for key, value in file_data.items():
-                if target_layers is not None:
-                    parts = key.split(".")
-
-                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
-                        continue
-
-                if key in model_params:
-                    model_params[key].data.copy_(value)
-                elif key in model_buffers:
-                    model_buffers[key].data.copy_(value)
-
-        return 0
-
     def __getattr__(self, __name: str) -> Any:
         """
         Special method to delegate attribute access to the original Huggingface LM class.
@@ -278,7 +223,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
 
-        val = getattr(self.
+        val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
@@ -295,54 +240,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return model
 
-    def validate_quantization_config(quantize_config):
-        if quantize_config is not None:
-            q_format = quantize_config.get("format")
-            q_precision = quantize_config.get("precision")
-
-            if q_format not in SUPPORTED_QUANTIZATIONS:
-                raise ValueError(
-                    f"Invalid quantization format: {q_format}. "
-                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
-                )
-
-            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f"Invalid precision: {q_precision} for format: {q_format}. "
-                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
-                )
-
-        return quantize_config
-
-    @classmethod
-    def set_quantize_env(cls, quantize_config):
-        RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
-        quantize_config = cls.validate_quantization_config(quantize_config)
-        if quantize_config is not None:
-            q_precision = quantize_config.get("precision")
-            quant_bits = q_precision.split("w")[1].split("a")[0]
-            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
-            return RBLN_QUANT_BITS_ENV
-        return None
-
-    @classmethod
-    def reset_quantize_env(cls, env_var_name):
-        if env_var_name is not None and env_var_name in os.environ:
-            del os.environ[env_var_name]
-
-    @classmethod
-    def manage_quantize_env(cls, func):
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            quantize_config = kwargs.get("quantize_config")
-            quantize_env_var = cls.set_quantize_env(quantize_config)
-            try:
-                return func(*args, **kwargs)
-            finally:
-                cls.reset_quantize_env(quantize_env_var)
-
-        return wrapper
-
     @classmethod
     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
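The hunk above removes the per-class quantization helpers (`validate_quantization_config`, `set_quantize_env`, `reset_quantize_env`, `manage_quantize_env`); later hunks delegate the same work to `QuantizationManager` in `utils/rbln_quantization.py` (`+120` in the file list). Below is a self-contained sketch of the environment-scoping pattern the removed helpers implemented and which `QuantizationManager.with_quantization_env` presumably centralizes; everything except the `RBLN_QUANT_BITS` variable name and the `w4a16` precision string (both taken from the removed code) is illustrative, not the library's API.

```python
import functools
import os


def with_quant_bits_env(func):
    """Set RBLN_QUANT_BITS around a call, then restore the environment (illustrative only)."""

    @functools.wraps(func)
    def wrapper(*args, quantize_config=None, **kwargs):
        env_set = False
        if quantize_config is not None:
            # "w4a16" -> weight bit width "4", mirroring the removed set_quantize_env()
            quant_bits = quantize_config["precision"].split("w")[1].split("a")[0]
            os.environ["RBLN_QUANT_BITS"] = quant_bits
            env_set = True
        try:
            return func(*args, **kwargs)
        finally:
            if env_set:
                os.environ.pop("RBLN_QUANT_BITS", None)

    return wrapper


@with_quant_bits_env
def compile_model(**kwargs):
    return os.environ.get("RBLN_QUANT_BITS")  # visible only while compiling


print(compile_model(quantize_config={"format": "rbln", "precision": "w4a16"}))  # prints "4"
```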
@@ -351,6 +248,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
             wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
 
+        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
+
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
     @classmethod
@@ -369,9 +268,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
         dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
 
+        wrapped_model.phase = "prefill"
         prefill_scripted_model = torch.jit.trace(
             wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
         )
+        wrapped_model.phase = "decode"
         dec_scripted_model = torch.jit.trace(
             wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
         )
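The two added `wrapped_model.phase = ...` lines let a single wrapper module produce two specialized TorchScript graphs: Python control flow is resolved at trace time, so each `torch.jit.trace` call bakes in whichever branch the flag selects. A minimal, generic sketch of that pattern; `ToyWrapper` is hypothetical, only the flag-then-trace flow mirrors the diff.

```python
import torch


class ToyWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.phase = "prefill"
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        if self.phase == "prefill":
            return self.proj(x)           # whole-prompt path
        return self.proj(x[:, -1:, :])    # single-step decode path


wrapper = ToyWrapper().eval()
example = torch.zeros(1, 4, 8)

wrapper.phase = "prefill"
prefill_graph = torch.jit.trace(wrapper, example, check_trace=False)
wrapper.phase = "decode"
decode_graph = torch.jit.trace(wrapper, example, check_trace=False)
```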
@@ -394,6 +295,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         prefill_ir, dec_ir = scripted_model_to_ir()
         # Caching prefill_decoder/decoder I/O
         cache_index_offset = 5
+
         connections = [
             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
             for i in range(model.config.num_hidden_layers * 2)
@@ -402,7 +304,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # Extract quantize_config from rbln_config
         quantize_config = rbln_config.model_cfg.get("quantization", None)
 
-        @
+        @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
             # Remove quantize_config from kwargs
             kwargs.pop("quantize_config", None)
@@ -432,10 +334,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ) -> RBLNConfig:
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_quantization = rbln_kwargs.get("quantization", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-
-        rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
 
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
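With validation now routed through `QuantizationManager`, the `quantization` entry still arrives via the `rbln_`-prefixed keyword arguments that `_get_rbln_config` reads (`max_seq_len`, `batch_size`, `use_inputs_embeds`, `quantization`). A hedged usage sketch: the model class, checkpoint id, and `export=True` flag are assumptions carried over from earlier optimum-rbln releases, and the format/precision pair is the one combination the removed `SUPPORTED_QUANTIZATIONS` table listed.

```python
from optimum.rbln import RBLNLlamaForCausalLM  # assumed export, as in earlier releases

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    export=True,                 # compile for the RBLN NPU at load time (assumed flag)
    rbln_quantization={"format": "rbln", "precision": "w4a16"},
    rbln_max_seq_len=4096,
    rbln_batch_size=1,
)
```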
@@ -610,8 +510,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
-        # from llava_next forward args
-        batch_idx: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll
@@ -633,7 +531,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     input_ids=input_tensor if inputs_embeds is None else None,
                     inputs_embeds=input_tensor if inputs_embeds is not None else None,
                     cache_position=cache_position,
-                    batch_idx=b_idx
+                    batch_idx=b_idx,
                 )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
@@ -734,11 +632,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+        if input_tensors is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
 
         batch_size = input_tensors.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                )
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
 
         logits, _ = self.decoder(
@@ -751,31 +662,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )
 
         return logits
-
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefll
-        if cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                batch_idx=batch_idx,
-            )
-        # decoder
-        else:
-            logits = self._forward_decoder(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-            )
-
-        return RBLNDecoderOnlyOutput(
-            logits=logits,
-        )
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 from transformers import AutoModelForDepthEstimation
 from transformers.modeling_outputs import DepthEstimatorOutput
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
@@ -20,62 +20,78 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-import
+from typing import TYPE_CHECKING
+
+import torch.nn as nn
 
 from ....utils import logging
-from ...models.decoderonly import (
+from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
-    RotaryEmbedding,
 )
 
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as ExaoneForCausalLM
+
 logger = logging.get_logger(__name__)
 
 
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def
-    [… 28 removed lines not captured in this view …]
-    return
-
-    def
-    [… 9 removed lines not captured in this view …]
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            if self.attn_impl == "eager":
+                new_self_attn = ExaoneAttention(layer.attn.attention)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = ExaoneFlashAttention(
+                    layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = ExaoneLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = ExaoneModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class ExaoneModel(DecoderOnlyModel):
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f
+
+
+class ExaoneLayer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2
+
+
+class ExaoneAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
+
+
+class ExaoneFlashAttention(DecoderOnlyFlashAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -21,10 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+
+from transformers import AutoModelForCausalLM
+
 from ....utils import logging
 from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .exaone_architecture import ExaoneForCausalLMWrapper
-from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM
 
 
 logger = logging.get_logger(__name__)
@@ -45,7 +47,7 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
 
     _decoder_wrapper_cls = ExaoneForCausalLMWrapper
-
+    _hf_class = AutoModelForCausalLM
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
@@ -21,113 +21,42 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import
-
-import
-[… 5 removed lines not captured in this view …]
-    DecoderOnlyDecoderLayer,
+from typing import TYPE_CHECKING
+
+from ...models.decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    slice_and_unsqueeze_cos_sin,
 )
-from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES
-
-
-class GemmaWrapper(DecoderOnlyWrapper):
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": GemmaModel.forward,
-                "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
-            }
-        )
-        return forward_dict
-
-
-class GemmaModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        cache_pos_for_partitions: Optional[torch.Tensor] = None,
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        inputs_embeds = self.embed_tokens(input_ids)
-        hidden_states = inputs_embeds
 
-        ##### GEMMA change from llama#####
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
-
+if TYPE_CHECKING:
+    from transformers import GemmaForCausalLM
 
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
 
-[… 24 removed lines not captured in this view …]
-            hidden_states = layer_outputs[0]
-
-            updated_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class GemmaWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GemmaModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class GemmaModel(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self._original_mod.config.hidden_size**0.5