optimum-rbln 0.2.1a0__py3-none-any.whl → 0.2.1a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +3 -10
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +1 -10
- optimum/rbln/diffusers/modeling_diffusers.py +1 -10
- optimum/rbln/diffusers/models/__init__.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -10
- optimum/rbln/diffusers/models/controlnet.py +1 -10
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -10
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -10
- optimum/rbln/diffusers/models/unets/__init__.py +1 -10
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -10
- optimum/rbln/diffusers/pipelines/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -10
- optimum/rbln/modeling.py +1 -10
- optimum/rbln/modeling_base.py +1 -10
- optimum/rbln/modeling_config.py +1 -10
- optimum/rbln/ops/__init__.py +1 -10
- optimum/rbln/ops/attn.py +9 -18
- optimum/rbln/ops/flash_attn.py +5 -14
- optimum/rbln/ops/kv_cache_update.py +1 -10
- optimum/rbln/transformers/__init__.py +5 -12
- optimum/rbln/transformers/modeling_alias.py +1 -14
- optimum/rbln/transformers/modeling_generic.py +40 -21
- optimum/rbln/transformers/modeling_rope_utils.py +28 -0
- optimum/rbln/transformers/models/__init__.py +3 -12
- optimum/rbln/transformers/models/auto/__init__.py +1 -10
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -10
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -10
- optimum/rbln/transformers/models/bart/__init__.py +1 -10
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -10
- optimum/rbln/transformers/models/bart/modeling_bart.py +14 -13
- optimum/rbln/transformers/models/bert/__init__.py +2 -11
- optimum/rbln/transformers/models/bert/modeling_bert.py +23 -13
- optimum/rbln/transformers/models/clip/__init__.py +1 -10
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -10
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +54 -69
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +14 -14
- optimum/rbln/transformers/models/dpt/__init__.py +1 -10
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -10
- optimum/rbln/transformers/models/exaone/__init__.py +1 -10
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +1 -10
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +1 -10
- optimum/rbln/transformers/models/gemma/__init__.py +1 -10
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -10
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -10
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -10
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +1 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -10
- optimum/rbln/transformers/models/llama/__init__.py +1 -10
- optimum/rbln/transformers/models/llama/llama_architecture.py +1 -10
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -10
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -10
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +95 -89
- optimum/rbln/transformers/models/midm/__init__.py +1 -10
- optimum/rbln/transformers/models/midm/midm_architecture.py +1 -10
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -10
- optimum/rbln/transformers/models/mistral/__init__.py +1 -10
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -10
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -10
- optimum/rbln/transformers/models/phi/__init__.py +1 -10
- optimum/rbln/transformers/models/phi/modeling_phi.py +1 -10
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -10
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -10
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +1 -10
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +1 -10
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -10
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -10
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +16 -42
- optimum/rbln/transformers/models/t5/__init__.py +1 -10
- optimum/rbln/transformers/models/t5/modeling_t5.py +14 -15
- optimum/rbln/transformers/models/t5/t5_architecture.py +30 -16
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -10
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -10
- optimum/rbln/transformers/models/whisper/__init__.py +1 -10
- optimum/rbln/transformers/models/whisper/generation_whisper.py +2 -11
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -10
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -10
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +24 -12
- optimum/rbln/transformers/utils/rbln_quantization.py +6 -10
- optimum/rbln/utils/__init__.py +1 -10
- optimum/rbln/utils/decorator_utils.py +1 -10
- optimum/rbln/utils/hub.py +1 -10
- optimum/rbln/utils/import_utils.py +1 -10
- optimum/rbln/utils/logging.py +1 -10
- optimum/rbln/utils/model_utils.py +1 -10
- optimum/rbln/utils/runtime_utils.py +1 -10
- optimum/rbln/utils/save_utils.py +2 -10
- optimum/rbln/utils/submodule.py +1 -10
- {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/METADATA +6 -4
- optimum_rbln-0.2.1a2.dist-info/RECORD +114 -0
- optimum_rbln-0.2.1a2.dist-info/licenses/LICENSE +201 -0
- optimum_rbln-0.2.1a0.dist-info/RECORD +0 -114
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +0 -288
- {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from ....utils import logging
|
25
16
|
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
26
17
|
from .phi_architecture import PhiWrapper
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import TYPE_CHECKING, Optional, Tuple
|
25
16
|
|
26
17
|
import torch
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_qwen2 import RBLNQwen2ForCausalLM
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from ....utils import logging
|
25
16
|
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
26
17
|
from .qwen2_architecture import QWEN2Wrapper
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
|
25
16
|
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_seq2seq import RBLNModelForSeq2SeqLM
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
import logging
|
26
17
|
from abc import ABC
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import Tuple
|
25
16
|
|
26
17
|
import torch
|
@@ -429,7 +420,7 @@ class Seq2SeqSelfAttention(nn.Module):
|
|
429
420
|
pass
|
430
421
|
|
431
422
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
|
432
|
-
return tensor.view(bsz,
|
423
|
+
return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)
|
433
424
|
|
434
425
|
def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
435
426
|
"""Projects input hidden states into query, key, and value representations.
|
@@ -459,38 +450,21 @@ class Seq2SeqSelfAttention(nn.Module):
|
|
459
450
|
key_states = self._shape(key_states, -1, bsz)
|
460
451
|
value_states = self._shape(value_states, -1, bsz)
|
461
452
|
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
query_state,
|
475
|
-
key_state,
|
476
|
-
value_state,
|
477
|
-
attn_mask,
|
478
|
-
past_key_state,
|
479
|
-
past_value_state,
|
480
|
-
cache_position[b_idx][0],
|
481
|
-
torch.tensor(1.0, dtype=torch.float32), # scale
|
482
|
-
)
|
483
|
-
|
484
|
-
attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim).transpose(1, 2)
|
485
|
-
attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
|
486
|
-
|
487
|
-
all_key_states.append(key_state.squeeze(2))
|
488
|
-
all_value_states.append(value_state.squeeze(2))
|
489
|
-
all_attn_output.append(attn_output)
|
453
|
+
attn_output, key_states, value_states = self.attn_decode(
|
454
|
+
query_states,
|
455
|
+
key_states,
|
456
|
+
value_states,
|
457
|
+
attention_mask.unsqueeze(
|
458
|
+
2
|
459
|
+
), # Unsqueeze group axis since CustomKernel expects it for group query attention
|
460
|
+
past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
|
461
|
+
past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
|
462
|
+
cache_position.squeeze(1),
|
463
|
+
torch.tensor(1.0, dtype=torch.float32), # scale
|
464
|
+
)
|
490
465
|
|
491
|
-
|
492
|
-
|
493
|
-
attn_output = torch.cat(all_attn_output, dim=0)
|
466
|
+
attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
|
467
|
+
attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
|
494
468
|
|
495
469
|
attn_output = self.out_proj(attn_output)
|
496
470
|
present_key_value = (key_states, value_states)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,20 +12,10 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
26
17
|
|
27
18
|
import torch
|
28
|
-
import transformers
|
29
19
|
from transformers import (
|
30
20
|
AutoModelForTextEncoding,
|
31
21
|
PretrainedConfig,
|
@@ -130,20 +120,29 @@ class RBLNT5EncoderModel(RBLNModel):
|
|
130
120
|
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
131
121
|
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
132
122
|
|
123
|
+
signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
|
124
|
+
|
133
125
|
if rbln_model_input_names is None:
|
134
126
|
for tokenizer in preprocessors:
|
135
127
|
if hasattr(tokenizer, "model_input_names"):
|
136
|
-
rbln_model_input_names = tokenizer.model_input_names
|
128
|
+
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
129
|
+
|
130
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
131
|
+
if invalid_params:
|
132
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
137
133
|
break
|
138
134
|
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
139
135
|
rbln_model_input_names = cls.rbln_model_input_names
|
140
136
|
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
141
|
-
original_model_class = getattr(transformers, model_config.architectures[0])
|
142
|
-
input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
|
143
137
|
raise ValueError(
|
144
138
|
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
145
|
-
f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(
|
139
|
+
f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(signature_params)})"
|
146
140
|
)
|
141
|
+
else:
|
142
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
143
|
+
if invalid_params:
|
144
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
145
|
+
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
147
146
|
|
148
147
|
if rbln_batch_size is None:
|
149
148
|
rbln_batch_size = 1
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import Tuple
|
25
16
|
|
26
17
|
import torch
|
@@ -156,6 +147,11 @@ class T5CrossAttention(nn.Module):
|
|
156
147
|
def __init__(self, attn):
|
157
148
|
super().__init__()
|
158
149
|
self.attn = attn
|
150
|
+
self.q = attn.q
|
151
|
+
self.o = attn.o
|
152
|
+
self.n_heads = attn.n_heads
|
153
|
+
self.key_value_proj_dim = attn.key_value_proj_dim
|
154
|
+
self.inner_dim = attn.inner_dim
|
159
155
|
|
160
156
|
def forward(
|
161
157
|
self,
|
@@ -164,9 +160,27 @@ class T5CrossAttention(nn.Module):
|
|
164
160
|
attention_mask: torch.Tensor = None,
|
165
161
|
key_value_states: torch.Tensor = None,
|
166
162
|
):
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
163
|
+
batch_size = hidden_states.shape[0]
|
164
|
+
|
165
|
+
query_states = self.q(hidden_states)
|
166
|
+
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
167
|
+
|
168
|
+
# reuse k,v, cross_attentions
|
169
|
+
key_states = past_key_value[0]
|
170
|
+
value_states = past_key_value[1]
|
171
|
+
|
172
|
+
# compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
|
173
|
+
scores = torch.matmul(query_states, key_states.transpose(3, 2))
|
174
|
+
scores += attention_mask
|
175
|
+
|
176
|
+
# (batch_size, n_heads, seq_length, key_length)
|
177
|
+
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
|
178
|
+
attn_output = torch.matmul(attn_weights, value_states)
|
179
|
+
|
180
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
181
|
+
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
|
182
|
+
attn_output = self.o(attn_output)
|
183
|
+
|
184
|
+
outputs = (attn_output, past_key_value)
|
185
|
+
|
186
|
+
return outputs
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import logging
|
25
16
|
from typing import TYPE_CHECKING, Any, Dict, Union
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_whisper import RBLNWhisperForConditionalGeneration
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Copyright
|
15
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
16
16
|
|
17
17
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
18
18
|
# you may not use this file except in compliance with the License.
|
@@ -26,15 +26,6 @@
|
|
26
26
|
# See the License for the specific language governing permissions and
|
27
27
|
# limitations under the License.
|
28
28
|
|
29
|
-
# Portions of this software are licensed under the Apache License,
|
30
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
31
|
-
# additional information regarding copyright ownership.
|
32
|
-
|
33
|
-
# All other portions of this software, including proprietary code,
|
34
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
35
|
-
# copied, modified, or distributed without prior written permission
|
36
|
-
# from Rebellions Inc.
|
37
|
-
|
38
29
|
"""
|
39
30
|
Generation utilities for Whisper.
|
40
31
|
Modified from `transformers.models.whisper.generation_whisper.py`
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
import logging
|
26
17
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import Optional, Tuple, Union
|
25
16
|
|
26
17
|
import torch
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_xlm_roberta import RBLNXLMRobertaModel
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
15
|
+
import inspect
|
24
16
|
import logging
|
25
17
|
from typing import TYPE_CHECKING, Optional, Union
|
26
18
|
|
@@ -66,9 +58,29 @@ class RBLNXLMRobertaModel(RBLNModel):
|
|
66
58
|
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
67
59
|
raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
68
60
|
|
61
|
+
signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
|
62
|
+
|
69
63
|
if rbln_model_input_names is None:
|
70
|
-
|
71
|
-
|
64
|
+
for tokenizer in preprocessors:
|
65
|
+
if hasattr(tokenizer, "model_input_names"):
|
66
|
+
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
67
|
+
|
68
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
69
|
+
if invalid_params:
|
70
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
71
|
+
break
|
72
|
+
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
73
|
+
rbln_model_input_names = cls.rbln_model_input_names
|
74
|
+
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
75
|
+
raise ValueError(
|
76
|
+
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
77
|
+
f"and be sure to make the order of the inputs same as XLMRobertaModel forward() arguments like ({list(signature_params)})"
|
78
|
+
)
|
79
|
+
else:
|
80
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
81
|
+
if invalid_params:
|
82
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
83
|
+
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
72
84
|
|
73
85
|
if rbln_batch_size is None:
|
74
86
|
rbln_batch_size = 1
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import functools
|
25
16
|
import glob
|
26
17
|
import os
|
@@ -135,6 +126,8 @@ def update_layers_to_quantize(module: torch.nn.Module) -> None:
|
|
135
126
|
"""
|
136
127
|
Updates specified linear layers to quantized (qlinear) layers in the given module.
|
137
128
|
"""
|
129
|
+
|
130
|
+
logger.debug("Updating layers to be quantized") # TODO(jongho): remove.
|
138
131
|
processed_layers = []
|
139
132
|
|
140
133
|
for name, layer in module.named_modules():
|
@@ -151,6 +144,7 @@ def load_weights(model, model_id, n_layer=None):
|
|
151
144
|
"""
|
152
145
|
Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
|
153
146
|
"""
|
147
|
+
logger.debug("Loading the quantized weights into the CPU.") # TODO(jongho): remove.
|
154
148
|
|
155
149
|
model_params = dict(model.named_parameters(recurse=True))
|
156
150
|
model_buffers = dict(model.named_buffers(recurse=True))
|
@@ -172,6 +166,8 @@ def load_weights(model, model_id, n_layer=None):
|
|
172
166
|
elif key in model_buffers:
|
173
167
|
model_buffers[key].data.copy_(value)
|
174
168
|
|
169
|
+
logger.debug("Loaded the quantized weights into the CPU.")
|
170
|
+
|
175
171
|
|
176
172
|
def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
|
177
173
|
"""
|