optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
- optimum/rbln/diffusers/models/controlnet.py +36 -62
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +117 -144
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -28
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
- optimum_rbln-0.1.13.dist-info/RECORD +107 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/RECORD +0 -93
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

```diff
@@ -20,15 +20,17 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import functools
 import glob
-import
-
+import inspect
+import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
-import rebel
-import torch
+import rebel
+import torch
+import transformers
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
```
```diff
@@ -36,11 +38,13 @@ from transformers.utils import ModelOutput
 
 from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.timer_utils import rbln_timer
+from .decoderonly_architecture import DecoderOnlyWrapper
 
 
-logger =
+logger = get_logger()
 
 if TYPE_CHECKING:
     from transformers import (
```
```diff
@@ -97,22 +101,50 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
-
+    generate_idx: torch.Tensor = None
 
 
-class RBLNDecoderOnlyModelForCausalLM(RBLNModel
+class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     """
-
-    This
-
-
-
-
-
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+    - This class is designed to be subclassed by specific model implementations
+      (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+    - Subclasses should implement model-specific conversion logic.
+    - The class handles RBLN-specific optimizations automatically during compilation
     """
 
     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
+    _decoder_wrapper_cls = DecoderOnlyWrapper
+    _original_cls = None
+
+    @classmethod
+    @property
+    def original_cls(cls):
+        """
+        Lazily loads and caches the corresponding Hugging Face model class.
+        Removes 'RBLN' prefix from the class name to get the original class name
+        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
+        the transformers module.
+
+        Returns:
+            type: The original Hugging Face model class
+        """
+        if cls._original_cls is None:
+            hf_original_cls_name = cls.__name__[4:]
+            cls._original_cls = getattr(transformers, hf_original_cls_name)
+        return cls._original_cls
 
     def __post_init__(self, **kwargs):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
```
```diff
@@ -231,6 +263,26 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
         return 0
 
+    def __getattr__(self, __name: str) -> Any:
+        """
+        Special method to delegate attribute access to the original Huggingface LM class.
+        This method is called when an attribute is not found in the current instance's dictionary.
+        It enables transparent access to the original model's attributes and methods while maintaining
+        proper method binding.
+
+        The method implements a delegation pattern that:
+        1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        2. For other attributes: Returns them directly from the original class
+        """
+
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(self.original_cls, __name)
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
     @classmethod
     def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
         rbln_kwargs = kwargs.get("rbln_kwargs", {})
```
```diff
@@ -243,6 +295,64 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
         return model
 
+    def validate_quantization_config(quantize_config):
+        if quantize_config is not None:
+            q_format = quantize_config.get("format")
+            q_precision = quantize_config.get("precision")
+
+            if q_format not in SUPPORTED_QUANTIZATIONS:
+                raise ValueError(
+                    f"Invalid quantization format: {q_format}. "
+                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
+                )
+
+            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                raise ValueError(
+                    f"Invalid precision: {q_precision} for format: {q_format}. "
+                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
+                )
+
+        return quantize_config
+
+    @classmethod
+    def set_quantize_env(cls, quantize_config):
+        RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
+        quantize_config = cls.validate_quantization_config(quantize_config)
+        if quantize_config is not None:
+            q_precision = quantize_config.get("precision")
+            quant_bits = q_precision.split("w")[1].split("a")[0]
+            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
+            return RBLN_QUANT_BITS_ENV
+        return None
+
+    @classmethod
+    def reset_quantize_env(cls, env_var_name):
+        if env_var_name is not None and env_var_name in os.environ:
+            del os.environ[env_var_name]
+
+    @classmethod
+    def manage_quantize_env(cls, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            quantize_config = kwargs.get("quantize_config")
+            quantize_env_var = cls.set_quantize_env(quantize_config)
+            try:
+                return func(*args, **kwargs)
+            finally:
+                cls.reset_quantize_env(quantize_env_var)
+
+        return wrapper
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
+
+        # If the model wrapper supports rbln-custom-flash-attention
+        if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
+            wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+
+        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
```
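`manage_quantize_env` is an environment-scoping decorator: it derives the weight bit width from a precision string such as `"w4a16"`, exports it for the duration of the wrapped call, and removes it again in a `finally` block. A self-contained sketch of the same mechanics (the stub below is illustrative, not the package's):

```python
import functools
import os

QUANT_BITS_ENV = "RBLN_QUANT_BITS"


def manage_quantize_env(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        cfg = kwargs.get("quantize_config")
        exported = None
        if cfg is not None:
            # "w4a16" -> weight bits "4": the text between 'w' and 'a'
            os.environ[QUANT_BITS_ENV] = cfg["precision"].split("w")[1].split("a")[0]
            exported = QUANT_BITS_ENV
        try:
            return func(*args, **kwargs)
        finally:
            if exported is not None:
                os.environ.pop(exported, None)

    return wrapper


@manage_quantize_env
def compile_stub(**kwargs):
    return os.environ.get(QUANT_BITS_ENV)


print(compile_stub(quantize_config={"format": "rbln", "precision": "w4a16"}))  # 4
print(QUANT_BITS_ENV in os.environ)  # False -- cleaned up after the call
```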
```diff
@@ -252,14 +362,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_rbln_compile_config = rbln_compile_configs[0]
         dec_rbln_compile_config = rbln_compile_configs[1]
 
-        @rbln_timer("
+        @rbln_timer("JIT trace")
         def get_scripted_model():
             # This function is nested to dealloc the example inputs before compilation.
+            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
             prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=
-
-            batch_index = 3
-            dec_example_inputs[batch_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
 
             prefill_scripted_model = torch.jit.trace(
                 wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
```
```diff
@@ -271,7 +379,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
         prefill_scripted_model, dec_scripted_model = get_scripted_model()
 
-        @rbln_timer("
+        @rbln_timer("Model conversion")
         def scripted_model_to_ir():
             prefill_ir = rebel.torchscript_to_ir(
                 prefill_scripted_model,
```
```diff
@@ -291,7 +399,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             for i in range(model.config.num_hidden_layers * 2)
         ]
 
-
+        # Extract quantize_config from rbln_config
+        quantize_config = rbln_config.model_cfg.get("quantization", None)
+
+        @cls.manage_quantize_env
+        def compile_model(*args, **kwargs):
+            # Remove quantize_config from kwargs
+            kwargs.pop("quantize_config", None)
+
+            # Call rebel.compile with the updated kwargs
+            return rebel.compile(*args, **kwargs)
+
+        compiled_model = compile_model(
             prefill_ir,
             dec_ir,
             connections=connections,
```
```diff
@@ -299,7 +418,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             npu=prefill_rbln_compile_config.npu,
             tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
             use_weight_sharing=True,
+            quantize_config=quantize_config,
         )
+
         return compiled_model
 
     @classmethod
```
```diff
@@ -314,6 +435,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         rbln_quantization = rbln_kwargs.get("quantization", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
 
+        rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
             rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
```
```diff
@@ -330,16 +453,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
 
-        if rbln_quantization is not None:
-            q_format = rbln_quantization.get("format", None)
-            q_precision = rbln_quantization.get("precision", None)
-
-            if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
-                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
-                )
-
         def get_input_info(
             batch_size,
             query_length,
```
```diff
@@ -439,50 +552,41 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
-
+        generate_idx: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         model_inputs = {}
-
-        if past_cached_length is None:
-            # huggingface make dummy_input_ids if model_input_name is "input_embeds"
-            # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
-            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-                input_tensors = inputs_embeds
-            else:
-                input_tensors = input_ids
+        is_prefill_phase = generate_idx is None
 
-
-
-
-            past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
-            for i in range(batch_size):
-                input_tensor = input_tensors[i]
-                input_tensor = input_tensor[attention_mask[i] == 1]
-                valid_len = input_tensor.shape[0]
-                cache_position = torch.arange(0, valid_len, dtype=torch.int32)
-                past_cached_length[i] = valid_len
-                l_input_tensors.append(input_tensor.unsqueeze(0))
-                cache_positions.append(cache_position.unsqueeze(0))
-
-            input_tensors = l_input_tensors
-            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
-                model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
-            else:
-                model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
-        # decoder phase
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
         else:
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
             input_ids = input_ids[:, -1:]
-
-
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        if inputs_embeds is not None:
+            if self.rbln_config.model_cfg["use_inputs_embeds"]:
+                model_inputs.update({"inputs_embeds": inputs_embeds})
+            else:
+                raise ValueError(
+                    "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                )
+        else:
             model_inputs.update({"input_ids": input_ids})
 
         model_inputs.update(
             {
-                "
-                "
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
             }
         )
 
```
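The rewritten `prepare_inputs_for_generation` keys the phase off `generate_idx`: when it is absent the call is a prefill and per-row prompt lengths are read from the attention mask; when present the call is a decode step, only the newest token is fed, and the index advances by one. A toy walk-through of that bookkeeping:

```python
import torch

# Two padded prompts of true length 3 and 4.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

# Prefill phase: generate_idx is None, so derive it from the mask.
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
print(generate_idx.squeeze(-1).tolist())  # [3, 4]

# Decode phase: feed only the last token; cache_position is the write index.
input_ids = torch.tensor([[5, 6, 7, 8], [9, 10, 11, 12]])
input_ids = input_ids[:, -1:]
cache_position = generate_idx
generate_idx = generate_idx + 1
print(input_ids.tolist(), cache_position.squeeze(-1).tolist())  # [[8], [12]] [3, 4]
```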
```diff
@@ -494,42 +598,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         model_kwargs: Dict[str, Any],
         **kwargs,
     ) -> Dict[str, Any]:
-        # update
-        model_kwargs["
+        # update generate_idx
+        model_kwargs["generate_idx"] = outputs.generate_idx
 
         return model_kwargs
 
     def forward(
         self,
-        input_ids: Optional[
-        inputs_embeds: Optional[
-        cache_position:
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        generate_idx: Optional[torch.Tensor] = None,
+        # from llava_next forward args
         batch_idx: Optional[int] = None,
-        past_cached_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        # prefll
-        if
+        # prefll
+        if cache_position is None:
             logits = []
-            input_tensors =
-
+            input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = input_tensors.shape[0]
+
+            for b_idx in range(batch_size):
+                # Transform inputs as vllm format
+                if attention_mask is not None:
+                    input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                else:
+                    input_tensor = input_tensors[b_idx : b_idx + 1]
+
+                cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+
                 logit = self._forward_prefill(
                     input_ids=input_tensor if inputs_embeds is None else None,
                     inputs_embeds=input_tensor if inputs_embeds is not None else None,
-                    cache_position=
-                    batch_idx=batch_idx,
+                    cache_position=cache_position,
+                    batch_idx=b_idx if batch_idx is None else batch_idx,  # Llava-next prefill
                 )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
-        #
-        elif cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                batch_idx=batch_idx,
-            )
-        # common decoder
+        # decoder
         else:
             logits = self._forward_decoder(
                 input_ids=input_ids,
```
```diff
@@ -539,7 +647,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
         return RBLNDecoderOnlyOutput(
             logits=logits,
-
+            generate_idx=generate_idx,
         )
 
     def _forward_prefill(
```
```diff
@@ -567,23 +675,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             torch.empty(size=[], dtype=torch.int16, device="cpu"),
         ]
 
-
-            model_input_name = "inputs_embeds"
-        else:
-            model_input_name = "input_ids"
-
-        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
-
+        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
         query_length = input_tensors.shape[1]
-
+        _attention_mask = self.prefill_attention_mask.clone()
+
         for step in range(0, query_length, self.prefill_chunk_size):
-
-
-
-            if
-                input_tensors = torch.nn.functional.pad(input_tensors, (0,
+            # pad input_tensors & cache_position for prefill_chunk
+            if (step + self.prefill_chunk_size) > query_length:
+                pad_to_chunk = step + self.prefill_chunk_size - query_length
+                if inputs_embeds is not None:
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
                 else:
-                input_tensors = torch.nn.functional.pad(input_tensors, (0,
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
 
             cache_position = torch.cat(
                 [
```
```diff
@@ -597,25 +700,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 dim=-1,
             )
 
-
-
+            # slice input_tensor & cache_position with prefill_chunk_size
+            _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
+            _cache_position = cache_position[:, step : step + self.prefill_chunk_size]
 
+            # update attention_mask
             if step >= self.prefill_chunk_size:
-
-
+                _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+            _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
-            query_idx = query_length % self.prefill_chunk_size
+            query_idx = (query_length - 1) % self.prefill_chunk_size
 
             logits, _ = self.prefill_decoder(
-                input_ids=
-                inputs_embeds=
-                attention_mask=
-                cache_position=
+                input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
+                inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
+                attention_mask=_attention_mask.contiguous(),
+                cache_position=_cache_position.contiguous(),
                 batch_position=torch.tensor(batch_idx, dtype=torch.int16),
                 query_idx=torch.tensor(query_idx, dtype=torch.int16),
                 out=out_buffers,
             )
 
+        # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
         self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
         self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
```
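The prefill loop above processes the prompt in fixed `prefill_chunk_size` slices: the tail is padded up to a chunk boundary, each iteration slices one chunk of inputs and cache positions, and `query_idx = (query_length - 1) % prefill_chunk_size` locates the logit of the last real token inside the final chunk (the `- 1` is the off-by-one fix versus the removed line). A shape-only sketch of that arithmetic:

```python
import torch

prefill_chunk_size = 128
query_length = 300  # illustrative prompt length
input_tensors = torch.randint(0, 1000, (1, query_length))

# Pad the tail so the final slice is a full chunk (equivalent to the in-loop padding).
pad_to_chunk = -query_length % prefill_chunk_size  # 84
input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))

chunks = [
    input_tensors[:, step : step + prefill_chunk_size]
    for step in range(0, query_length, prefill_chunk_size)
]
print(len(chunks), chunks[-1].shape)  # 3 torch.Size([1, 128])

# Position of the last real token's logit within the final chunk.
print((query_length - 1) % prefill_chunk_size)  # 43
```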
```diff
@@ -627,11 +733,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         inputs_embeds: torch.Tensor = None,
         cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
-
-            model_input_name = "inputs_embeds"
-        else:
-            model_input_name = "input_ids"
-        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
 
         batch_size = input_tensors.shape[0]
 
```
```diff
@@ -640,8 +742,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
 
         logits, _ = self.decoder(
-            input_ids=input_tensors.contiguous() if
-            inputs_embeds=input_tensors.contiguous() if
+            input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
+            inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
             attention_mask=self.dec_attn_mask.contiguous(),
             cache_position=cache_position.contiguous(),
             batch_position=torch.tensor(0, dtype=torch.int16),
```
```diff
@@ -649,3 +751,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         )
 
         return logits
+
+    def vllm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_idx: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # prefll
+        if cache_position.shape[-1] > 1:
+            logits = self._forward_prefill(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                batch_idx=batch_idx,
+            )
+        # decoder
+        else:
+            logits = self._forward_decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+            )
+
+        return RBLNDecoderOnlyOutput(
+            logits=logits,
+        )
```
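The new `vllm_forward` keeps the older dispatch rule that the generic `forward` dropped: a multi-position `cache_position` marks prefill, a single position marks decode. In sketch form:

```python
import torch


def phase(cache_position: torch.Tensor) -> str:
    # Prefill writes a span of cache positions; decode advances exactly one.
    return "prefill" if cache_position.shape[-1] > 1 else "decode"


print(phase(torch.arange(7, dtype=torch.int32).unsqueeze(0)))  # prefill
print(phase(torch.tensor([[7]], dtype=torch.int32)))           # decode
```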
optimum/rbln/transformers/models/exaone/__init__.py (new file)

```diff
@@ -0,0 +1,32 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import os
+from os import environ
+
+
+this_path = os.path.abspath(__file__)
+local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
+environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
+
+from .modeling_exaone import RBLNExaoneForCausalLM
```
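The new package `__init__` pins a module-local `hf_hub_cached` directory through the `LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM` environment variable before importing the model. A quick check of what the path expression evaluates to (POSIX separators assumed, since the code splits on a literal `/`; the path below is illustrative):

```python
import os

this_path = "/opt/site-packages/optimum/rbln/transformers/models/exaone/__init__.py"
local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
print(local_dir)
# /opt/site-packages/optimum/rbln/transformers/models/exaone/hf_hub_cached
```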
optimum/rbln/transformers/models/exaone/exaone_architecture.py (new file)

```diff
@@ -0,0 +1,81 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+import torch
+
+from ....utils import logging
+from ...models.decoderonly import (
+    DecoderOnlyAttention,
+    DecoderOnlyDecoderLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    RotaryEmbedding,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
+    """A wrapper class for the Exaone model with a language modeling head."""
+
+    def __init__(self, model, max_seq_len, kvcache_partition_len=None):
+        super(DecoderOnlyWrapper, self).__init__()
+        self.config = model.config
+        self.model = self.convert_attribute_name(model.transformer)
+        self.lm_head = model.lm_head
+        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+        if kvcache_partition_len is not None:
+            # WORKAROUND : for passing partition length as a value to the rbln compiler.
+            # What is actually used is the shape of this tensor.
+            self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
+            self.attn_implementation = "flash_attn_rbln"
+            logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
+        else:
+            self.kvcache_partition_size = None
+            self.attn_implementation = "eager"
+
+    @staticmethod
+    def convert_attribute_name(model):
+        model.embed_tokens = model.wte
+        model.norm = model.ln_f
+        model.layers = model.h
+
+        for layer in model.layers:
+            layer.input_layernorm = layer.ln_1
+            layer.self_attn = layer.attn.attention
+            layer.post_attention_layernorm = layer.ln_2
+            layer.self_attn.o_proj = layer.self_attn.out_proj
+
+        return model
+
+    def get_forward_dict(self):
+        forward_dict = {}
+        forward_dict.update(
+            {
+                "wrapper": DecoderOnlyModel.forward,
+                "model": DecoderOnlyDecoderLayer.forward,
+                "decoder_layer": DecoderOnlyAttention.forward,
+            }
+        )
+        return forward_dict
```
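`convert_attribute_name` adapts Exaone's GPT-style module names (`wte`, `ln_f`, `h`, `ln_1`, `ln_2`, `out_proj`) to the Llama-style names the shared decoder-only forward passes expect, by aliasing the existing submodules under the new names rather than copying weights. A toy illustration of the aliasing idea:

```python
import torch.nn as nn

# Toy module with GPT/Exaone-style attribute names.
gpt_style = nn.Module()
gpt_style.wte = nn.Embedding(10, 4)
gpt_style.ln_f = nn.LayerNorm(4)

# Alias, don't copy: both names now refer to the same parameter-carrying modules.
gpt_style.embed_tokens = gpt_style.wte
gpt_style.norm = gpt_style.ln_f

print(gpt_style.embed_tokens is gpt_style.wte)  # True -- no duplicated weights
```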