optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +8 -1
- optimum/rbln/utils/logging.py +38 -1
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/METADATA +0 -106
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -20,6 +20,7 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
 import inspect
 from dataclasses import dataclass
 from pathlib import Path
@@ -27,28 +28,26 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un

 import rebel
 import torch
+from rebel.compile_context import CompileContext
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

 from ....modeling import RBLNModel
-from ....modeling_config import
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.timer_utils import rbln_timer
 from ...utils.rbln_quantization import QuantizationManager
-from .decoderonly_architecture import
+from .decoderonly_architecture import (
+    DecoderOnlyWrapper,
+    validate_attention_method,
+)


 logger = get_logger()

 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


 class RBLNRuntimeModel(RBLNPytorchRuntime):
@@ -60,32 +59,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         inputs_embeds: torch.Tensor,
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        query_idx: torch.Tensor,
         **kwargs,
     ):
         if inputs_embeds is None:
             inp = input_ids
             if self.embed_tokens is not None:
                 inp = self.embed_tokens(inp)
-
-            return super().forward(
-                inp,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_idx,
-                **kwargs,
-            )
         else:
-
-
-
-
-
-
-
-
+            inp = inputs_embeds
+
+        return super().forward(
+            inp,
+            attention_mask,
+            cache_position,
+            **kwargs,
+        )


 @dataclass
@@ -243,11 +231,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
-
-
-        if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
-            wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
-
+        wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
+        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb

         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
@@ -258,72 +243,46 @@
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
-
-
+        prefill_compile_config = rbln_compile_configs[0]
+        dec_compile_config = rbln_compile_configs[1]

-
-        def get_scripted_model():
-            # This function is nested to dealloc the example inputs before compilation.
-            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
-            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+        context = CompileContext(use_weight_sharing=True)

-
-
-
-            )
-            wrapped_model.phase = "decode"
-            dec_scripted_model = torch.jit.trace(
-                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
-            )
-            return prefill_scripted_model, dec_scripted_model
-
-        prefill_scripted_model, dec_scripted_model = get_scripted_model()
+        # Here we use meta tensor, for the memory efficiency.
+        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)

-
-
-
-
-
-
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
-        )
-        return prefill_ir, dec_ir
+        # Mark static tensors (self kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+            if "past_key_values" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)

-
-        # Caching prefill_decoder/decoder I/O
-        cache_index_offset = 5
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)

-        connections = [
-            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-            for i in range(model.config.num_hidden_layers * 2)
-        ]
-
-        # Extract quantize_config from rbln_config
         quantize_config = rbln_config.model_cfg.get("quantization", None)

         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            wrapped_model.phase = "prefill"
+            compiled_prefill = RBLNModel.compile(
+                wrapped_model,
+                prefill_compile_config,
+                example_inputs=prefill_example_inputs,
+                compile_context=context,
+            )
+
+            wrapped_model.phase = "decode"
+            compiled_decoder = RBLNModel.compile(
+                wrapped_model,
+                dec_compile_config,
+                example_inputs=dec_example_inputs,
+                compile_context=context,
+            )
+            return {"prefill": compiled_prefill, "decoder": compiled_decoder}

-        return
+        return compile_model(quantize_config=quantize_config)

     @classmethod
     def _get_rbln_config(
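
The rewritten compile step above drops the TorchScript/IR plumbing and compiles the same wrapped model twice, once per phase, while CompileContext(use_weight_sharing=True) plus mark_static_address pin the past_key_values dummy tensors so the prefill and decode graphs address one shared KV cache. The snippet below is only a toy, pure-PyTorch sketch of that sharing idea, not optimum-rbln code; all names and sizes are made up.

import torch

# Toy sketch: one statically allocated KV buffer reused by a "prefill" and a
# "decode" step, mirroring what the shared static tensors above arrange for the
# two compiled graphs on-device.
MAX_SEQ_LEN, HEAD_DIM = 8, 4
kv_cache = torch.zeros(MAX_SEQ_LEN, HEAD_DIM)  # single buffer written by both phases

def prefill(keys: torch.Tensor) -> None:
    # write the whole prompt's keys into the shared cache
    kv_cache[: keys.shape[0]] = keys

def decode(new_key: torch.Tensor, pos: int) -> torch.Tensor:
    # append one key, then attend over everything cached so far
    kv_cache[pos] = new_key
    return kv_cache[: pos + 1]

prefill(torch.randn(3, HEAD_DIM))       # prompt of length 3
ctx = decode(torch.randn(HEAD_DIM), 3)  # first generated token sees 4 cached keys
print(ctx.shape)                        # torch.Size([4, 4])
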
@@ -335,6 +294,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+        rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
+        rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))

         prefill_chunk_size = 128
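
The new attn_impl and kvcache_partition_len entries are ordinary rbln_kwargs. Assuming the usual rbln_-prefixed keyword convention of optimum-rbln still applies, they would be supplied at export time roughly as below; the model id and values are placeholders, not taken from this diff.

from optimum.rbln import RBLNLlamaForCausalLM

# Hypothetical export call: rbln_attn_impl / rbln_kvcache_partition_len feed the
# attn_impl / kvcache_partition_len entries read from rbln_kwargs above.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    export=True,
    rbln_max_seq_len=8192,
    rbln_attn_impl="flash_attn",
    rbln_kvcache_partition_len=4096,
)
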
@@ -344,9 +305,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified.")
+
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

+        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+            rbln_attn_impl=rbln_attn_impl,
+            rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_max_seq_len=rbln_max_seq_len,
+        )
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
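
validate_attention_method itself lives in decoderonly_architecture.py and is not shown in this diff; the sketch below is only a guess at the kind of normalization such a helper might perform, not the actual implementation.

from typing import Optional, Tuple

# Hypothetical sketch only -- the real checks are defined elsewhere in the package.
def validate_attention_method(
    rbln_attn_impl: Optional[str],
    rbln_kvcache_partition_len: Optional[int],
    rbln_max_seq_len: int,
) -> Tuple[str, Optional[int]]:
    if rbln_attn_impl is None and rbln_kvcache_partition_len is not None:
        rbln_attn_impl = "flash_attn"  # a partition length only makes sense with flash attention
    rbln_attn_impl = rbln_attn_impl or "eager"
    if rbln_attn_impl == "flash_attn":
        if rbln_kvcache_partition_len is None:
            rbln_kvcache_partition_len = rbln_max_seq_len  # fall back to a single partition
        if rbln_max_seq_len % rbln_kvcache_partition_len != 0:
            raise ValueError("`rbln_max_seq_len` must be a multiple of `rbln_kvcache_partition_len`.")
    return rbln_attn_impl, rbln_kvcache_partition_len
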
@@ -372,9 +340,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                    [batch_size, query_length],
                    "int32",
                ),
-                ("batch_position", [], "int16"),
-                ("query_idx", [], "int16"),
            ]
+            if query_length > 1:
+                input_info.extend(
+                    [
+                        ("batch_position", [], "int16"),
+                        ("query_position", [], "int16"),
+                    ]
+                )

            input_info.extend(
                [
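
A condensed view of the conditional above: the two scalar position inputs now exist only in the prefill graph (query_length > 1), so the single-token decode graph loses them from its signature. The input_info entries are (name, shape, dtype) triples as elsewhere in this file.

def position_inputs(query_length: int):
    # condensed from the hunk above
    input_info = []
    if query_length > 1:
        input_info.extend(
            [
                ("batch_position", [], "int16"),
                ("query_position", [], "int16"),
            ]
        )
    return input_info

print(position_inputs(128))  # prefill graph: two extra scalar int16 inputs
print(position_inputs(1))    # decode graph: []
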
@@ -407,12 +380,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             hidden_size=hidden_size,
         )

-
-
+        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[
+            compile_cfgs=[prefill_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )

@@ -422,6 +395,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "batch_size": rbln_batch_size,
                 "prefill_chunk_size": prefill_chunk_size,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
+                "kvcache_partition_len": rbln_kvcache_partition_len,
+                "attn_impl": rbln_attn_impl,
             }
         )

@@ -432,12 +407,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

     @classmethod
     def _create_runtimes(
-        cls,
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
     ) -> List[rebel.Runtime]:
-
+        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+            cls._raise_missing_compiled_file_error(["prefill", "decoder"])
+
         return [
-            compiled_models[0].create_runtime(
-
+            compiled_models[0].create_runtime(
+                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            ),
+            compiled_models[1].create_runtime(
+                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            ),
         ]

     def get_decoder(self):
@@ -569,12 +553,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 ],
                 dtype=torch.float32,
                 device="cpu",
-            )
-            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+            )
         ]

         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
         query_length = input_tensors.shape[1]
+        if query_length > self.max_seq_len:
+            raise ValueError(
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+            )
+
         _attention_mask = self.prefill_attention_mask.clone()

         for step in range(0, query_length, self.prefill_chunk_size):
@@ -607,15 +595,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

-
+            query_position = (query_length - 1) % self.prefill_chunk_size

-            logits
+            logits = self.prefill_decoder(
                 input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
                 inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
                 attention_mask=_attention_mask.contiguous(),
                 cache_position=_cache_position.contiguous(),
                 batch_position=torch.tensor(batch_idx, dtype=torch.int16),
-
+                query_position=torch.tensor(query_position, dtype=torch.int16),
                 out=out_buffers,
             )

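
query_position is the offset of the last real prompt token inside the final prefill chunk, presumably so the prefill graph can pick out that position's logits. A quick check with the prefill_chunk_size = 128 set earlier in this file and an assumed 300-token prompt:

prefill_chunk_size = 128
query_length = 300                                       # assumed prompt length
query_position = (query_length - 1) % prefill_chunk_size
print(query_position)                                    # 43: offset of token 299 within the third chunk
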
@@ -651,14 +639,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                 )
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-        logits, _ = self.decoder(
+        logits = self.decoder(
             input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
             inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
             attention_mask=self.dec_attn_mask.contiguous(),
             cache_position=cache_position.contiguous(),
-            batch_position=torch.tensor(0, dtype=torch.int16),
-            query_idx=torch.tensor(0, dtype=torch.int16),
         )

         return logits

optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -20,6 +20,7 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+
 from typing import TYPE_CHECKING

 import torch.nn as nn
@@ -58,7 +59,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):

             new_layer = ExaoneLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = ExaoneModel(causal_lm.transformer, new_layers)
+        new_model = ExaoneModel(causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len)
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm

@@ -85,7 +86,6 @@ class ExaoneAttention(DecoderOnlyAttention):
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.out_proj
-        self.num_key_value_heads = self._original_mod.num_key_value_heads


 class ExaoneFlashAttention(DecoderOnlyFlashAttention):
@@ -94,4 +94,3 @@ class ExaoneFlashAttention(DecoderOnlyFlashAttention):
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.out_proj
-        self.num_key_value_heads = self._original_mod.num_key_value_heads

optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -51,7 +51,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
            new_layer = DecoderOnlyLayer(layer, new_self_attn)
            new_layers.append(new_layer)
-        new_model = GemmaModel(causal_lm.model, new_layers)
+        new_model = GemmaModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm


optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

@@ -21,6 +21,7 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+import math
 from typing import TYPE_CHECKING, Tuple

 import torch
@@ -54,8 +55,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):


 class GPT2Model(DecoderOnlyModel):
-    mask_fmin = torch.finfo(torch.float32).min
-
     def get_last_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.ln_f

@@ -79,16 +78,17 @@ class GPT2Attention(DecoderOnlyAttention):
         self.c_attn = self._original_mod.c_attn
         self.o_proj = self._original_mod.c_proj
         self.split_size = self._original_mod.split_size
-        self.num_key_value_heads = self._original_mod.num_heads

     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states

-    def
-
-
-
-
-
-
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)
+
+        if self._original_mod.scale_attn_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx
+
+        return scale
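
get_attn_scale reproduces the score scaling of Hugging Face's GPT2Attention: divide by sqrt(head_dim) when scale_attn_weights is set, and additionally by layer_idx + 1 when scale_attn_by_inverse_layer_idx is set. A quick check with assumed values (head_dim=64, layer_idx=3):

import math

head_dim, layer_idx = 64, 3      # assumed values, not from the diff
scale = 1.0
scale /= math.sqrt(head_dim)     # scale_attn_weights              -> 1/8
scale /= 1 + layer_idx           # scale_attn_by_inverse_layer_idx -> a further 1/4 on layer 3
print(scale)                     # 0.03125 == 1/32
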

optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,7 +23,7 @@

 from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .gpt2_architecture import GPT2Wrapper
+from .gpt2_architecture import GPT2Wrapper


 logger = logging.get_logger(__name__)

optimum/rbln/transformers/models/midm/midm_architecture.py

@@ -21,12 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+import math
 from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn

-from ....transformers.models.decoderonly.decoderonly_architecture import rotate_half
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
     DecoderOnlyForCausalLM,
@@ -34,6 +34,7 @@ from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyModel,
     DecoderOnlyWrapper,
     apply_rotary_pos_emb_partial,
+    rotate_half,
 )


@@ -77,8 +78,6 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):


 class MidmModel(DecoderOnlyModel):
-    mask_fmin = -10000.0
-
     def get_layernorm1p(self, module: nn.LayerNorm):
         def layernorm1p(input: torch.Tensor):
             """Applies Layer Normalization with a slight modification on the weights."""
@@ -135,14 +134,15 @@ class MidmAttention(DecoderOnlyAttention):
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states

-    def
-
-
-
-
-
-
-
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)
+
+        if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx
+
+        return scale

     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
         return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])

optimum/rbln/transformers/models/phi/phi_architecture.py

@@ -65,7 +65,6 @@ class PhiAttention(DecoderOnlyAttention):
         self.o_proj = self._original_mod.dense
         self.qk_layernorm = self._original_mod.qk_layernorm
         self.rotary_ndims = self._original_mod.rotary_ndims
-        self.num_key_value_heads = self.num_heads

     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
@@ -90,7 +89,7 @@ class PhiLayer(DecoderOnlyLayer):
         self,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
-
+        seq_positions: torch.LongTensor,
         batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
@@ -103,7 +102,7 @@
         attn_outputs, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
+            seq_positions=seq_positions,
             batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,