optimum-rbln 0.2.1a0__py3-none-any.whl → 0.2.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. optimum/rbln/__init__.py +3 -10
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +1 -10
  4. optimum/rbln/diffusers/modeling_diffusers.py +1 -10
  5. optimum/rbln/diffusers/models/__init__.py +1 -10
  6. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -10
  7. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -10
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -10
  9. optimum/rbln/diffusers/models/controlnet.py +1 -10
  10. optimum/rbln/diffusers/models/transformers/__init__.py +1 -10
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -10
  12. optimum/rbln/diffusers/models/unets/__init__.py +1 -10
  13. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -10
  14. optimum/rbln/diffusers/pipelines/__init__.py +1 -10
  15. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +1 -10
  16. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -10
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +1 -10
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -10
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -10
  20. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -10
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -10
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -10
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -10
  24. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -10
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +1 -10
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -10
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -10
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -10
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -10
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -10
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -10
  32. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -10
  33. optimum/rbln/modeling.py +1 -10
  34. optimum/rbln/modeling_base.py +1 -10
  35. optimum/rbln/modeling_config.py +1 -10
  36. optimum/rbln/ops/__init__.py +1 -10
  37. optimum/rbln/ops/attn.py +9 -18
  38. optimum/rbln/ops/flash_attn.py +5 -14
  39. optimum/rbln/ops/kv_cache_update.py +1 -10
  40. optimum/rbln/transformers/__init__.py +5 -12
  41. optimum/rbln/transformers/modeling_alias.py +1 -14
  42. optimum/rbln/transformers/modeling_generic.py +40 -21
  43. optimum/rbln/transformers/modeling_rope_utils.py +28 -0
  44. optimum/rbln/transformers/models/__init__.py +3 -12
  45. optimum/rbln/transformers/models/auto/__init__.py +1 -10
  46. optimum/rbln/transformers/models/auto/auto_factory.py +1 -10
  47. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -10
  48. optimum/rbln/transformers/models/bart/__init__.py +1 -10
  49. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -10
  50. optimum/rbln/transformers/models/bart/modeling_bart.py +14 -13
  51. optimum/rbln/transformers/models/bert/__init__.py +2 -11
  52. optimum/rbln/transformers/models/bert/modeling_bert.py +23 -13
  53. optimum/rbln/transformers/models/clip/__init__.py +1 -10
  54. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -10
  55. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -10
  56. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +54 -69
  57. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +14 -14
  58. optimum/rbln/transformers/models/dpt/__init__.py +1 -10
  59. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -10
  60. optimum/rbln/transformers/models/exaone/__init__.py +1 -10
  61. optimum/rbln/transformers/models/exaone/exaone_architecture.py +1 -10
  62. optimum/rbln/transformers/models/exaone/modeling_exaone.py +1 -10
  63. optimum/rbln/transformers/models/gemma/__init__.py +1 -10
  64. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -10
  65. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -10
  66. optimum/rbln/transformers/models/gpt2/__init__.py +1 -10
  67. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +1 -10
  68. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -10
  69. optimum/rbln/transformers/models/llama/__init__.py +1 -10
  70. optimum/rbln/transformers/models/llama/llama_architecture.py +1 -10
  71. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -10
  72. optimum/rbln/transformers/models/llava_next/__init__.py +1 -10
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +95 -89
  74. optimum/rbln/transformers/models/midm/__init__.py +1 -10
  75. optimum/rbln/transformers/models/midm/midm_architecture.py +1 -10
  76. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -10
  77. optimum/rbln/transformers/models/mistral/__init__.py +1 -10
  78. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -10
  79. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -10
  80. optimum/rbln/transformers/models/phi/__init__.py +1 -10
  81. optimum/rbln/transformers/models/phi/modeling_phi.py +1 -10
  82. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -10
  83. optimum/rbln/transformers/models/qwen2/__init__.py +1 -10
  84. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +1 -10
  85. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +1 -10
  86. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -10
  87. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -10
  88. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +16 -42
  89. optimum/rbln/transformers/models/t5/__init__.py +1 -10
  90. optimum/rbln/transformers/models/t5/modeling_t5.py +14 -15
  91. optimum/rbln/transformers/models/t5/t5_architecture.py +30 -16
  92. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -10
  93. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -10
  94. optimum/rbln/transformers/models/whisper/__init__.py +1 -10
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +2 -11
  96. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -10
  97. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -10
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -10
  99. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +24 -12
  100. optimum/rbln/transformers/utils/rbln_quantization.py +6 -10
  101. optimum/rbln/utils/__init__.py +1 -10
  102. optimum/rbln/utils/decorator_utils.py +1 -10
  103. optimum/rbln/utils/hub.py +1 -10
  104. optimum/rbln/utils/import_utils.py +1 -10
  105. optimum/rbln/utils/logging.py +1 -10
  106. optimum/rbln/utils/model_utils.py +1 -10
  107. optimum/rbln/utils/runtime_utils.py +1 -10
  108. optimum/rbln/utils/save_utils.py +2 -10
  109. optimum/rbln/utils/submodule.py +1 -10
  110. {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/METADATA +6 -4
  111. optimum_rbln-0.2.1a2.dist-info/RECORD +114 -0
  112. optimum_rbln-0.2.1a2.dist-info/licenses/LICENSE +201 -0
  113. optimum_rbln-0.2.1a0.dist-info/RECORD +0 -114
  114. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +0 -288
  115. {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import Tuple
25
16
 
26
17
  import torch
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
26
17
 
@@ -67,23 +58,33 @@ class RBLNBartModel(RBLNModel):
67
58
  if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
68
59
  raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
69
60
 
61
+ signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
62
+
70
63
  if rbln_model_input_names is None:
71
64
  for tokenizer in preprocessors:
72
65
  if hasattr(tokenizer, "model_input_names"):
73
- rbln_model_input_names = tokenizer.model_input_names
66
+ rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
74
67
  # BartModel's forward() does not take token_type_ids as input.
75
68
  # (Added because some of the tokenizers includes 'token_type_ids')
76
69
  if "token_type_ids" in rbln_model_input_names:
77
70
  rbln_model_input_names.remove("token_type_ids")
71
+
72
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
73
+ if invalid_params:
74
+ raise ValueError(f"Invalid model input names: {invalid_params}")
78
75
  break
79
76
  if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
80
77
  rbln_model_input_names = cls.rbln_model_input_names
81
78
  elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
82
- input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
83
79
  raise ValueError(
84
80
  "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
85
- f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
81
+ f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(signature_params)})"
86
82
  )
83
+ else:
84
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
85
+ if invalid_params:
86
+ raise ValueError(f"Invalid model input names: {invalid_params}")
87
+ rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
87
88
 
88
89
  if rbln_batch_size is None:
89
90
  rbln_batch_size = 1
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
- from .modeling_bert import RBLNBertModel
15
+ from .modeling_bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  import logging
26
17
  from typing import TYPE_CHECKING, Any, Dict, Optional, Union
@@ -29,6 +20,7 @@ from transformers import PretrainedConfig
29
20
 
30
21
  from ....modeling import RBLNModel
31
22
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
23
+ from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForQuestionAnswering
32
24
 
33
25
 
34
26
  logger = logging.getLogger(__name__)
@@ -64,19 +56,29 @@ class RBLNBertModel(RBLNModel):
64
56
  if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
65
57
  raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
66
58
 
59
+ signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
60
+
67
61
  if rbln_model_input_names is None:
68
62
  for tokenizer in preprocessors:
69
63
  if hasattr(tokenizer, "model_input_names"):
70
- rbln_model_input_names = tokenizer.model_input_names
64
+ rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
65
+
66
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
67
+ if invalid_params:
68
+ raise ValueError(f"Invalid model input names: {invalid_params}")
71
69
  break
72
70
  if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
73
71
  rbln_model_input_names = cls.rbln_model_input_names
74
72
  elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
75
- input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
76
73
  raise ValueError(
77
74
  "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
78
- f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
75
+ f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(signature_params)})"
79
76
  )
77
+ else:
78
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
79
+ if invalid_params:
80
+ raise ValueError(f"Invalid model input names: {invalid_params}")
81
+ rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
80
82
 
81
83
  if rbln_batch_size is None:
82
84
  rbln_batch_size = 1
@@ -96,3 +98,11 @@ class RBLNBertModel(RBLNModel):
96
98
 
97
99
  rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
98
100
  return rbln_config
101
+
102
+
103
+ class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
104
+ rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
105
+
106
+
107
+ class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
108
+ rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import logging
25
16
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import math
25
16
  from typing import List, Optional, Tuple
26
17
 
@@ -553,15 +544,19 @@ class DecoderOnlyAttention(nn.Module):
553
544
  super().__init__()
554
545
  self._original_mod = self_attn
555
546
  self.layer_idx = self_attn.layer_idx
556
- self.num_heads = self._original_mod.num_heads
547
+ self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
548
+ self._original_mod.config, "num_attention_heads"
549
+ )
557
550
  self.head_dim = self._original_mod.head_dim
558
551
  self._phase = "prefill"
559
552
  self.scale = torch.tensor(self.get_attn_scale())
560
553
 
561
554
  if hasattr(self._original_mod, "num_key_value_heads"):
562
555
  self.num_key_value_heads = self._original_mod.num_key_value_heads
556
+ elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
557
+ self.num_key_value_heads = self._original_mod.config.num_key_value_heads
563
558
  else:
564
- self.num_key_value_heads = self._original_mod.num_heads
559
+ self.num_key_value_heads = self.num_heads
565
560
 
566
561
  self.attention = self.get_attention()
567
562
  self.__post_init__()
@@ -631,33 +626,21 @@ class DecoderOnlyAttention(nn.Module):
631
626
  if batch_size > 1 and self.phase == "prefill":
632
627
  raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
633
628
 
634
- # TODO(jongho): flash attn legacy. (clone)
635
- _seq_positions = seq_positions.clone().unsqueeze(1)
636
-
637
- _key_states = []
638
- _value_states = []
639
- _attn_outputs = []
640
- for b in range(batch_size):
641
- seq_position = _seq_positions[b][0]
642
- attn_output, key_state, value_state = self.attention(
643
- query_states[b].unsqueeze(0),
644
- key_states[b].unsqueeze(0),
645
- value_states[b].unsqueeze(0),
646
- attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
647
- past_key_state=past_key_values[self.layer_idx][0],
648
- past_value_state=past_key_values[self.layer_idx][1],
649
- batch_position=b if self.phase == "decode" else batch_position,
650
- seq_position=seq_position,
651
- scale=self.scale,
652
- )
653
- _key_states.append(key_state)
654
- _value_states.append(value_state)
655
- _attn_outputs.append(attn_output)
656
- key_states = torch.cat(_key_states, dim=0)
657
- value_states = torch.cat(_value_states, dim=0)
658
- attn_outputs = torch.cat(_attn_outputs, dim=0)
659
-
660
- attn_outputs = self.o_proj(attn_outputs)
629
+ attn_output, key_state, value_state = self.attention(
630
+ query_states,
631
+ key_states,
632
+ value_states,
633
+ attention_mask,
634
+ past_key_state=past_key_values[self.layer_idx][0],
635
+ past_value_state=past_key_values[self.layer_idx][1],
636
+ batch_position=None if self.phase == "decode" else batch_position,
637
+ seq_position=seq_positions,
638
+ scale=self.scale,
639
+ )
640
+ key_states = key_state
641
+ value_states = value_state
642
+
643
+ attn_outputs = self.o_proj(attn_output)
661
644
  past_key_values[self.layer_idx] = key_states, value_states
662
645
  return attn_outputs, past_key_values
663
646
 
@@ -703,8 +686,13 @@ class AttentionOp(nn.Module):
703
686
  value_state = value_state.unsqueeze(2)
704
687
  attn_mask = attn_mask.unsqueeze(2)
705
688
 
689
+ if self.phase == "decode":
690
+ batch_size = key_state.shape[0]
691
+ else:
692
+ batch_size = 1
693
+
706
694
  query_state = query_state.view(
707
- 1,
695
+ batch_size,
708
696
  self.num_key_value_heads,
709
697
  self.num_heads // self.num_key_value_heads,
710
698
  -1, # seq len
@@ -736,9 +724,9 @@ class AttentionOp(nn.Module):
736
724
  scale,
737
725
  )
738
726
 
739
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
727
+ attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
740
728
  attn_output = attn_output.transpose(1, 2).contiguous()
741
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
729
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
742
730
 
743
731
  return attn_output, key_state.squeeze(2), value_state.squeeze(2)
744
732
 
@@ -867,31 +855,23 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
867
855
  if cos is not None and sin is not None:
868
856
  query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
869
857
 
870
- _key_states = []
871
- _value_states = []
872
- _attn_outputs = []
873
- for b in range(batch_size):
874
- seq_position = seq_positions[b][0] # FIXME: Remove take-take pattern matching
875
- attn_output, key_state, value_state = self.attention(
876
- query_states[b].unsqueeze(0),
877
- key_states[b].unsqueeze(0),
878
- value_states[b].unsqueeze(0),
879
- attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
880
- past_key_state=past_key_values[self.layer_idx][0],
881
- past_value_state=past_key_values[self.layer_idx][1],
882
- batch_position=b if self.phase == "decode" else batch_position,
883
- seq_position=seq_position,
884
- scale=self.scale,
885
- )
886
- _key_states.append(key_state)
887
- _value_states.append(value_state)
888
- _attn_outputs.append(attn_output)
889
- key_states = torch.cat(_key_states, dim=0)
890
- value_states = torch.cat(_value_states, dim=0)
891
- attn_outputs = torch.cat(_attn_outputs, dim=0)
892
-
893
- attn_outputs = self.o_proj(attn_outputs)
858
+ attn_output, key_state, value_state = self.attention(
859
+ query_states,
860
+ key_states,
861
+ value_states,
862
+ attention_mask,
863
+ past_key_state=past_key_values[self.layer_idx][0],
864
+ past_value_state=past_key_values[self.layer_idx][1],
865
+ batch_position=None if self.phase == "decode" else batch_position,
866
+ seq_position=seq_positions,
867
+ scale=self.scale,
868
+ )
869
+ key_states = key_state
870
+ value_states = value_state
871
+
872
+ attn_outputs = self.o_proj(attn_output)
894
873
  past_key_values[self.layer_idx] = key_states, value_states
874
+
895
875
  return attn_outputs, past_key_values
896
876
 
897
877
 
@@ -917,8 +897,13 @@ class FlashAttentionOp(AttentionOp):
917
897
  value_state = value_state.unsqueeze(2)
918
898
  attn_mask = attn_mask.unsqueeze(2)
919
899
 
900
+ if self.phase == "decode":
901
+ batch_size = key_state.shape[0]
902
+ else:
903
+ batch_size = 1
904
+
920
905
  query_state = query_state.view(
921
- 1,
906
+ batch_size,
922
907
  self.num_key_value_heads,
923
908
  self.num_heads // self.num_key_value_heads,
924
909
  -1, # seq len
@@ -952,8 +937,8 @@ class FlashAttentionOp(AttentionOp):
952
937
  )
953
938
 
954
939
  # reshape for removing repeat_kv
955
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
940
+ attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
956
941
  attn_output = attn_output.transpose(1, 2).contiguous()
957
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
942
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
958
943
 
959
944
  return attn_output, key_state, value_state
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  from dataclasses import dataclass
26
17
  from pathlib import Path
@@ -218,14 +209,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
218
209
 
219
210
  @classmethod
220
211
  def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
212
+ logger.debug("Loading the LLM model to the CPU.") # TODO(jongho): Remove.
213
+
221
214
  rbln_kwargs = kwargs.get("rbln_kwargs", {})
222
215
  rbln_quantization = rbln_kwargs.get("quantization", None)
223
-
224
216
  if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
225
217
  model = cls.get_quantized_model(*args, **kwargs)
226
218
  else:
227
219
  model = super().get_pytorch_model(*args, **kwargs)
228
220
 
221
+ logger.debug("Loaded the LLM model to the CPU.")
229
222
  return model
230
223
 
231
224
  @classmethod
@@ -297,8 +290,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
297
290
  rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
298
291
  rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
299
292
  rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
293
+ rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
294
+
295
+ if rbln_prefill_chunk_size is None:
296
+ rbln_prefill_chunk_size = 128
297
+ elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
298
+ raise ValueError(
299
+ f"Invalid rbln_prefill_chunk_size: {rbln_prefill_chunk_size}. It must be a nonzero multiple of 64."
300
+ )
300
301
 
301
- prefill_chunk_size = 128
302
302
  if rbln_max_seq_len is None:
303
303
  rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
304
304
  model_config, "n_positions", None
@@ -369,7 +369,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
369
369
 
370
370
  prefill_input_info = get_input_info(
371
371
  batch_size=1,
372
- query_length=prefill_chunk_size,
372
+ query_length=rbln_prefill_chunk_size,
373
373
  use_inputs_embeds=rbln_use_inputs_embeds,
374
374
  hidden_size=hidden_size,
375
375
  )
@@ -393,7 +393,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
393
393
  {
394
394
  "max_seq_len": rbln_max_seq_len,
395
395
  "batch_size": rbln_batch_size,
396
- "prefill_chunk_size": prefill_chunk_size,
396
+ "prefill_chunk_size": rbln_prefill_chunk_size,
397
397
  "use_inputs_embeds": rbln_use_inputs_embeds,
398
398
  "kvcache_partition_len": rbln_kvcache_partition_len,
399
399
  "attn_impl": rbln_attn_impl,
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_dpt import RBLNDPTForDepthEstimation
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import logging
25
16
  from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import os
25
16
  from os import environ
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import TYPE_CHECKING
25
16
 
26
17
  import torch.nn as nn
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
 
25
16
  from transformers import AutoModelForCausalLM
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_gemma import RBLNGemmaForCausalLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import TYPE_CHECKING
25
16
 
26
17
  from ...models.decoderonly.decoderonly_architecture import (