optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -20,26 +20,24 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-import
-import glob
-import os
-from abc import ABC
+import inspect
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

-import rebel
-import torch
-from safetensors.torch import load_file
+import rebel
+import torch
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.timer_utils import rbln_timer
+from ...utils.rbln_quantization import QuantizationManager
+from .decoderonly_architecture import DecoderOnlyWrapper


 logger = get_logger()
@@ -52,12 +50,6 @@ if TYPE_CHECKING:
         PretrainedConfig,
     )

-SUPPORTED_QUANTIZATIONS = {
-    "rbln": [
-        "w4a16",
-    ],
-}
-

 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name", "embed_tokens"]
@@ -102,19 +94,30 @@ class RBLNDecoderOnlyOutput(ModelOutput):
     generate_idx: torch.Tensor = None


-class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
+class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     """
-
-    This
-
-
-
-
-
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+        - This class is designed to be subclassed by specific model implementations
+          (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+        - Subclasses should implement model-specific conversion logic.
+        - The class handles RBLN-specific optimizations automatically during compilation
     """

     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
+    _decoder_wrapper_cls = DecoderOnlyWrapper
+    _use_rotary_emb = True

     def __post_init__(self, **kwargs):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -173,6 +176,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def get_quantized_model(
         cls,
         model_id: str,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -182,56 +186,47 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import
+        from ...utils.rbln_quantization import prepare_model_for_quantization

         kwargs = cls.update_kwargs(kwargs)

-        config
-
-
-
-
-
-
-
-
+        if config is None:
+            config = AutoConfig.from_pretrained(
+                model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )

         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)

-
-
-        n_layer = kwargs.get("num_hidden_layers", None)
-        cls._load_weights_directly_to_model(model, model_id, n_layer)
+        prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))

         return model

-    def
+    def __getattr__(self, __name: str) -> Any:
         """
-
+        Special method to delegate attribute access to the original Huggingface LM class.
+        This method is called when an attribute is not found in the current instance's dictionary.
+        It enables transparent access to the original model's attributes and methods while maintaining
+        proper method binding.
+
+        The method implements a delegation pattern that:
+        1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        2. For other attributes: Returns them directly from the original class
         """

-
-
-        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
-
-        target_layers = list(range(n_layer)) if n_layer is not None else None
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

-
-
-
-
-            parts = key.split(".")
-
-            if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
-                continue
-
-            if key in model_params:
-                model_params[key].data.copy_(value)
-            elif key in model_buffers:
-                model_buffers[key].data.copy_(value)
-
-        return 0
+        val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val

     @classmethod
     def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
@@ -245,53 +240,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         return model

-    def validate_quantization_config(quantize_config):
-        if quantize_config is not None:
-            q_format = quantize_config.get("format")
-            q_precision = quantize_config.get("precision")
-
-            if q_format not in SUPPORTED_QUANTIZATIONS:
-                raise ValueError(
-                    f"Invalid quantization format: {q_format}. "
-                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
-                )
-
-            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f"Invalid precision: {q_precision} for format: {q_format}. "
-                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
-                )
-
-        return quantize_config
-
     @classmethod
-    def
-
-        quantize_config = cls.validate_quantization_config(quantize_config)
-        if quantize_config is not None:
-            q_precision = quantize_config.get("precision")
-            quant_bits = q_precision.split("w")[1].split("a")[0]
-            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
-            return RBLN_QUANT_BITS_ENV
-        return None
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}

-
-
-
-        del os.environ[env_var_name]
+        # If the model wrapper supports rbln-custom-flash-attention
+        if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
+            wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")

-
-
-
-        def wrapper(*args, **kwargs):
-            quantize_config = kwargs.get("quantize_config")
-            quantize_env_var = cls.set_quantize_env(quantize_config)
-            try:
-                return func(*args, **kwargs)
-            finally:
-                cls.reset_quantize_env(quantize_env_var)
-
-        return wrapper
+        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
+
+        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

     @classmethod
     @torch.inference_mode()
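Note: `wrap_model_if_needed` feature-detects the wrapper class through `inspect.signature` instead of branching per model. The capability check in isolation (toy wrappers for illustration):

```python
import inspect

class BasicWrapper:
    def __init__(self, model, max_seq_len):
        pass

class FlashAttnWrapper:
    def __init__(self, model, max_seq_len, kvcache_partition_len=None):
        pass

def build_wrapper_cfg(wrapper_cls, max_seq_len=4096, partition_len=2048):
    cfg = {"max_seq_len": max_seq_len}
    # Pass the optional kwarg only if the wrapper's __init__ declares it.
    if "kvcache_partition_len" in inspect.signature(wrapper_cls.__init__).parameters:
        cfg["kvcache_partition_len"] = partition_len
    return cfg

print(build_wrapper_cfg(BasicWrapper))      # {'max_seq_len': 4096}
print(build_wrapper_cfg(FlashAttnWrapper))  # also has 'kvcache_partition_len'
```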
@@ -305,15 +264,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         @rbln_timer("JIT trace")
         def get_scripted_model():
             # This function is nested to dealloc the example inputs before compilation.
+            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
             prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=
-
-            batch_index = 3
-            dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)

+            wrapped_model.phase = "prefill"
             prefill_scripted_model = torch.jit.trace(
                 wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
             )
+            wrapped_model.phase = "decode"
             dec_scripted_model = torch.jit.trace(
                 wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
             )
@@ -336,6 +295,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_ir, dec_ir = scripted_model_to_ir()
         # Caching prefill_decoder/decoder I/O
         cache_index_offset = 5
+
         connections = [
             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
             for i in range(model.config.num_hidden_layers * 2)
@@ -344,7 +304,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         # Extract quantize_config from rbln_config
         quantize_config = rbln_config.model_cfg.get("quantization", None)

-        @
+        @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
             # Remove quantize_config from kwargs
             kwargs.pop("quantize_config", None)
@@ -374,10 +334,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     ) -> RBLNConfig:
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_quantization = rbln_kwargs.get("quantization", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-
-        rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))

         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
@@ -552,8 +510,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
-        # from llava_next forward args
-        batch_idx: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll
@@ -575,7 +531,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 input_ids=input_tensor if inputs_embeds is None else None,
                 inputs_embeds=input_tensor if inputs_embeds is not None else None,
                 cache_position=cache_position,
-                batch_idx=b_idx
+                batch_idx=b_idx,
             )
             logits.append(logit)
         logits = torch.cat(logits, dim=0)
@@ -676,11 +632,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+        if input_tensors is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")

         batch_size = input_tensors.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                )
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

         logits, _ = self.decoder(
@@ -693,31 +662,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         )

         return logits
-
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefll
-        if cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                batch_idx=batch_idx,
-            )
-        # decoder
-        else:
-            logits = self._forward_decoder(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-            )
-
-        return RBLNDecoderOnlyOutput(
-            logits=logits,
-        )
--- a/optimum/rbln/transformers/models/dpt/modeling_dpt.py
+++ b/optimum/rbln/transformers/models/dpt/modeling_dpt.py
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 from transformers import AutoModelForDepthEstimation
 from transformers.modeling_outputs import DepthEstimatorOutput

-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig


--- a/optimum/rbln/transformers/models/exaone/exaone_architecture.py
+++ b/optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -20,53 +20,78 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING

+import torch.nn as nn

-from
+from ....utils import logging
+from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
 )


+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as ExaoneForCausalLM
+
+logger = logging.get_logger(__name__)
+
+
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""

-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            if self.attn_impl == "eager":
+                new_self_attn = ExaoneAttention(layer.attn.attention)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = ExaoneFlashAttention(
+                    layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = ExaoneLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = ExaoneModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class ExaoneModel(DecoderOnlyModel):
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f
+
+
+class ExaoneLayer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2
+
+
+class ExaoneAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
+
+
+class ExaoneFlashAttention(DecoderOnlyFlashAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
--- a/optimum/rbln/transformers/models/exaone/modeling_exaone.py
+++ b/optimum/rbln/transformers/models/exaone/modeling_exaone.py
@@ -21,21 +21,15 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable

-from
+from transformers import AutoModelForCausalLM
+
+from ....utils import logging
 from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .exaone_architecture import ExaoneForCausalLMWrapper
-from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM


-logger = logging.
-if TYPE_CHECKING:
-    from transformers import (
-        PreTrainedModel,
-    )
+logger = logging.get_logger(__name__)


 class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
@@ -52,25 +46,8 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):

     """

-
-
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return ExaoneForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Exaone.
-
-        Returns:
-            Any: Exaone's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(ExaoneForCausalLM, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = ExaoneForCausalLMWrapper
+    _hf_class = AutoModelForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
|