optimum-rbln 0.2.0__py3-none-any.whl → 0.2.1a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. optimum/rbln/__init__.py +1 -10
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +1 -10
  4. optimum/rbln/diffusers/modeling_diffusers.py +1 -10
  5. optimum/rbln/diffusers/models/__init__.py +1 -10
  6. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -10
  7. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -10
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -10
  9. optimum/rbln/diffusers/models/controlnet.py +1 -10
  10. optimum/rbln/diffusers/models/transformers/__init__.py +1 -10
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -10
  12. optimum/rbln/diffusers/models/unets/__init__.py +1 -10
  13. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -10
  14. optimum/rbln/diffusers/pipelines/__init__.py +1 -10
  15. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +1 -10
  16. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -10
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +1 -10
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -10
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -10
  20. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -10
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -10
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -10
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -10
  24. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -10
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +1 -10
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -10
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -10
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -10
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -10
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -10
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -10
  32. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -10
  33. optimum/rbln/modeling.py +1 -10
  34. optimum/rbln/modeling_base.py +1 -10
  35. optimum/rbln/modeling_config.py +1 -10
  36. optimum/rbln/ops/__init__.py +1 -10
  37. optimum/rbln/ops/attn.py +5 -14
  38. optimum/rbln/ops/flash_attn.py +5 -14
  39. optimum/rbln/ops/kv_cache_update.py +1 -10
  40. optimum/rbln/transformers/__init__.py +3 -12
  41. optimum/rbln/transformers/modeling_alias.py +1 -14
  42. optimum/rbln/transformers/modeling_generic.py +40 -21
  43. optimum/rbln/transformers/modeling_rope_utils.py +28 -0
  44. optimum/rbln/transformers/models/__init__.py +3 -12
  45. optimum/rbln/transformers/models/auto/__init__.py +1 -10
  46. optimum/rbln/transformers/models/auto/auto_factory.py +1 -10
  47. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -10
  48. optimum/rbln/transformers/models/bart/__init__.py +1 -10
  49. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -10
  50. optimum/rbln/transformers/models/bart/modeling_bart.py +14 -13
  51. optimum/rbln/transformers/models/bert/__init__.py +2 -11
  52. optimum/rbln/transformers/models/bert/modeling_bert.py +19 -13
  53. optimum/rbln/transformers/models/clip/__init__.py +1 -10
  54. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -10
  55. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -10
  56. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +48 -67
  57. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +14 -14
  58. optimum/rbln/transformers/models/dpt/__init__.py +1 -10
  59. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -10
  60. optimum/rbln/transformers/models/exaone/__init__.py +1 -10
  61. optimum/rbln/transformers/models/exaone/exaone_architecture.py +1 -10
  62. optimum/rbln/transformers/models/exaone/modeling_exaone.py +1 -10
  63. optimum/rbln/transformers/models/gemma/__init__.py +1 -10
  64. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -10
  65. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -10
  66. optimum/rbln/transformers/models/gpt2/__init__.py +1 -10
  67. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +1 -10
  68. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -10
  69. optimum/rbln/transformers/models/llama/__init__.py +1 -10
  70. optimum/rbln/transformers/models/llama/llama_architecture.py +1 -10
  71. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -10
  72. optimum/rbln/transformers/models/llava_next/__init__.py +1 -10
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +95 -89
  74. optimum/rbln/transformers/models/midm/__init__.py +1 -10
  75. optimum/rbln/transformers/models/midm/midm_architecture.py +1 -10
  76. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -10
  77. optimum/rbln/transformers/models/mistral/__init__.py +1 -10
  78. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -10
  79. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -10
  80. optimum/rbln/transformers/models/phi/__init__.py +1 -10
  81. optimum/rbln/transformers/models/phi/modeling_phi.py +1 -10
  82. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -10
  83. optimum/rbln/transformers/models/qwen2/__init__.py +1 -10
  84. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +1 -10
  85. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +1 -10
  86. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -10
  87. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -10
  88. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -10
  89. optimum/rbln/transformers/models/t5/__init__.py +1 -10
  90. optimum/rbln/transformers/models/t5/modeling_t5.py +14 -15
  91. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -10
  92. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -10
  93. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -10
  94. optimum/rbln/transformers/models/whisper/__init__.py +1 -10
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +2 -11
  96. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -10
  97. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -10
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -10
  99. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +24 -12
  100. optimum/rbln/transformers/utils/rbln_quantization.py +7 -11
  101. optimum/rbln/utils/__init__.py +1 -10
  102. optimum/rbln/utils/decorator_utils.py +1 -10
  103. optimum/rbln/utils/hub.py +1 -10
  104. optimum/rbln/utils/import_utils.py +2 -11
  105. optimum/rbln/utils/logging.py +2 -11
  106. optimum/rbln/utils/model_utils.py +1 -10
  107. optimum/rbln/utils/runtime_utils.py +1 -10
  108. optimum/rbln/utils/save_utils.py +2 -10
  109. optimum/rbln/utils/submodule.py +1 -10
  110. {optimum_rbln-0.2.0.dist-info → optimum_rbln-0.2.1a1.dist-info}/METADATA +11 -5
  111. optimum_rbln-0.2.1a1.dist-info/RECORD +114 -0
  112. optimum_rbln-0.2.1a1.dist-info/licenses/LICENSE +201 -0
  113. optimum_rbln-0.2.0.dist-info/RECORD +0 -114
  114. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +0 -288
  115. {optimum_rbln-0.2.0.dist-info → optimum_rbln-0.2.1a1.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
26
17
 
@@ -67,23 +58,33 @@ class RBLNBartModel(RBLNModel):
67
58
  if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
68
59
  raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
69
60
 
61
+ signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
62
+
70
63
  if rbln_model_input_names is None:
71
64
  for tokenizer in preprocessors:
72
65
  if hasattr(tokenizer, "model_input_names"):
73
- rbln_model_input_names = tokenizer.model_input_names
66
+ rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
74
67
  # BartModel's forward() does not take token_type_ids as input.
75
68
  # (Added because some of the tokenizers includes 'token_type_ids')
76
69
  if "token_type_ids" in rbln_model_input_names:
77
70
  rbln_model_input_names.remove("token_type_ids")
71
+
72
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
73
+ if invalid_params:
74
+ raise ValueError(f"Invalid model input names: {invalid_params}")
78
75
  break
79
76
  if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
80
77
  rbln_model_input_names = cls.rbln_model_input_names
81
78
  elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
82
- input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
83
79
  raise ValueError(
84
80
  "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
85
- f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
81
+ f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(signature_params)})"
86
82
  )
83
+ else:
84
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
85
+ if invalid_params:
86
+ raise ValueError(f"Invalid model input names: {invalid_params}")
87
+ rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
87
88
 
88
89
  if rbln_batch_size is None:
89
90
  rbln_batch_size = 1
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
- from .modeling_bert import RBLNBertModel
15
+ from .modeling_bert import RBLNBertForQuestionAnswering, RBLNBertModel
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  import logging
26
17
  from typing import TYPE_CHECKING, Any, Dict, Optional, Union
@@ -29,6 +20,7 @@ from transformers import PretrainedConfig
29
20
 
30
21
  from ....modeling import RBLNModel
31
22
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
23
+ from ...modeling_generic import RBLNModelForQuestionAnswering
32
24
 
33
25
 
34
26
  logger = logging.getLogger(__name__)
@@ -64,19 +56,29 @@ class RBLNBertModel(RBLNModel):
64
56
  if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
65
57
  raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
66
58
 
59
+ signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
60
+
67
61
  if rbln_model_input_names is None:
68
62
  for tokenizer in preprocessors:
69
63
  if hasattr(tokenizer, "model_input_names"):
70
- rbln_model_input_names = tokenizer.model_input_names
64
+ rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
65
+
66
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
67
+ if invalid_params:
68
+ raise ValueError(f"Invalid model input names: {invalid_params}")
71
69
  break
72
70
  if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
73
71
  rbln_model_input_names = cls.rbln_model_input_names
74
72
  elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
75
- input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
76
73
  raise ValueError(
77
74
  "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
78
- f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
75
+ f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(signature_params)})"
79
76
  )
77
+ else:
78
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
79
+ if invalid_params:
80
+ raise ValueError(f"Invalid model input names: {invalid_params}")
81
+ rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
80
82
 
81
83
  if rbln_batch_size is None:
82
84
  rbln_batch_size = 1
@@ -96,3 +98,7 @@ class RBLNBertModel(RBLNModel):
96
98
 
97
99
  rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
98
100
  return rbln_config
101
+
102
+
103
+ class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
104
+ rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import logging
25
16
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import math
25
16
  from typing import List, Optional, Tuple
26
17
 
@@ -631,33 +622,21 @@ class DecoderOnlyAttention(nn.Module):
631
622
  if batch_size > 1 and self.phase == "prefill":
632
623
  raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
633
624
 
634
- # TODO(jongho): flash attn legacy. (clone)
635
- _seq_positions = seq_positions.clone().unsqueeze(1)
636
-
637
- _key_states = []
638
- _value_states = []
639
- _attn_outputs = []
640
- for b in range(batch_size):
641
- seq_position = _seq_positions[b][0]
642
- attn_output, key_state, value_state = self.attention(
643
- query_states[b].unsqueeze(0),
644
- key_states[b].unsqueeze(0),
645
- value_states[b].unsqueeze(0),
646
- attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
647
- past_key_state=past_key_values[self.layer_idx][0],
648
- past_value_state=past_key_values[self.layer_idx][1],
649
- batch_position=b if self.phase == "decode" else batch_position,
650
- seq_position=seq_position,
651
- scale=self.scale,
652
- )
653
- _key_states.append(key_state)
654
- _value_states.append(value_state)
655
- _attn_outputs.append(attn_output)
656
- key_states = torch.cat(_key_states, dim=0)
657
- value_states = torch.cat(_value_states, dim=0)
658
- attn_outputs = torch.cat(_attn_outputs, dim=0)
659
-
660
- attn_outputs = self.o_proj(attn_outputs)
625
+ attn_output, key_state, value_state = self.attention(
626
+ query_states,
627
+ key_states,
628
+ value_states,
629
+ attention_mask,
630
+ past_key_state=past_key_values[self.layer_idx][0],
631
+ past_value_state=past_key_values[self.layer_idx][1],
632
+ batch_position=None if self.phase == "decode" else batch_position,
633
+ seq_position=seq_positions,
634
+ scale=self.scale,
635
+ )
636
+ key_states = key_state
637
+ value_states = value_state
638
+
639
+ attn_outputs = self.o_proj(attn_output)
661
640
  past_key_values[self.layer_idx] = key_states, value_states
662
641
  return attn_outputs, past_key_values
663
642
 
@@ -703,8 +682,13 @@ class AttentionOp(nn.Module):
703
682
  value_state = value_state.unsqueeze(2)
704
683
  attn_mask = attn_mask.unsqueeze(2)
705
684
 
685
+ if self.phase == "decode":
686
+ batch_size = key_state.shape[0]
687
+ else:
688
+ batch_size = 1
689
+
706
690
  query_state = query_state.view(
707
- 1,
691
+ batch_size,
708
692
  self.num_key_value_heads,
709
693
  self.num_heads // self.num_key_value_heads,
710
694
  -1, # seq len
@@ -736,9 +720,9 @@ class AttentionOp(nn.Module):
736
720
  scale,
737
721
  )
738
722
 
739
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
723
+ attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
740
724
  attn_output = attn_output.transpose(1, 2).contiguous()
741
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
725
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
742
726
 
743
727
  return attn_output, key_state.squeeze(2), value_state.squeeze(2)
744
728
 
@@ -867,31 +851,23 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
867
851
  if cos is not None and sin is not None:
868
852
  query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
869
853
 
870
- _key_states = []
871
- _value_states = []
872
- _attn_outputs = []
873
- for b in range(batch_size):
874
- seq_position = seq_positions[b][0] # FIXME: Remove take-take pattern matching
875
- attn_output, key_state, value_state = self.attention(
876
- query_states[b].unsqueeze(0),
877
- key_states[b].unsqueeze(0),
878
- value_states[b].unsqueeze(0),
879
- attention_mask[b].unsqueeze(0) if self.phase == "decode" else attention_mask,
880
- past_key_state=past_key_values[self.layer_idx][0],
881
- past_value_state=past_key_values[self.layer_idx][1],
882
- batch_position=b if self.phase == "decode" else batch_position,
883
- seq_position=seq_position,
884
- scale=self.scale,
885
- )
886
- _key_states.append(key_state)
887
- _value_states.append(value_state)
888
- _attn_outputs.append(attn_output)
889
- key_states = torch.cat(_key_states, dim=0)
890
- value_states = torch.cat(_value_states, dim=0)
891
- attn_outputs = torch.cat(_attn_outputs, dim=0)
892
-
893
- attn_outputs = self.o_proj(attn_outputs)
854
+ attn_output, key_state, value_state = self.attention(
855
+ query_states,
856
+ key_states,
857
+ value_states,
858
+ attention_mask,
859
+ past_key_state=past_key_values[self.layer_idx][0],
860
+ past_value_state=past_key_values[self.layer_idx][1],
861
+ batch_position=None if self.phase == "decode" else batch_position,
862
+ seq_position=seq_positions,
863
+ scale=self.scale,
864
+ )
865
+ key_states = key_state
866
+ value_states = value_state
867
+
868
+ attn_outputs = self.o_proj(attn_output)
894
869
  past_key_values[self.layer_idx] = key_states, value_states
870
+
895
871
  return attn_outputs, past_key_values
896
872
 
897
873
 
@@ -917,8 +893,13 @@ class FlashAttentionOp(AttentionOp):
917
893
  value_state = value_state.unsqueeze(2)
918
894
  attn_mask = attn_mask.unsqueeze(2)
919
895
 
896
+ if self.phase == "decode":
897
+ batch_size = key_state.shape[0]
898
+ else:
899
+ batch_size = 1
900
+
920
901
  query_state = query_state.view(
921
- 1,
902
+ batch_size,
922
903
  self.num_key_value_heads,
923
904
  self.num_heads // self.num_key_value_heads,
924
905
  -1, # seq len
@@ -952,8 +933,8 @@ class FlashAttentionOp(AttentionOp):
952
933
  )
953
934
 
954
935
  # reshape for removing repeat_kv
955
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
936
+ attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
956
937
  attn_output = attn_output.transpose(1, 2).contiguous()
957
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
938
+ attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
958
939
 
959
940
  return attn_output, key_state, value_state
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  from dataclasses import dataclass
26
17
  from pathlib import Path
@@ -218,14 +209,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
218
209
 
219
210
  @classmethod
220
211
  def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
212
+ logger.debug("Loading the LLM model to the CPU.") # TODO(jongho): Remove.
213
+
221
214
  rbln_kwargs = kwargs.get("rbln_kwargs", {})
222
215
  rbln_quantization = rbln_kwargs.get("quantization", None)
223
-
224
216
  if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
225
217
  model = cls.get_quantized_model(*args, **kwargs)
226
218
  else:
227
219
  model = super().get_pytorch_model(*args, **kwargs)
228
220
 
221
+ logger.debug("Loaded the LLM model to the CPU.")
229
222
  return model
230
223
 
231
224
  @classmethod
@@ -297,8 +290,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
297
290
  rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
298
291
  rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
299
292
  rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
293
+ rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
294
+
295
+ if rbln_prefill_chunk_size is None:
296
+ rbln_prefill_chunk_size = 128
297
+ elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
298
+ raise ValueError(
299
+ f"Invalid rbln_prefill_chunk_size: {rbln_prefill_chunk_size}. It must be a nonzero multiple of 64."
300
+ )
300
301
 
301
- prefill_chunk_size = 128
302
302
  if rbln_max_seq_len is None:
303
303
  rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
304
304
  model_config, "n_positions", None
@@ -369,7 +369,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
369
369
 
370
370
  prefill_input_info = get_input_info(
371
371
  batch_size=1,
372
- query_length=prefill_chunk_size,
372
+ query_length=rbln_prefill_chunk_size,
373
373
  use_inputs_embeds=rbln_use_inputs_embeds,
374
374
  hidden_size=hidden_size,
375
375
  )
@@ -393,7 +393,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
393
393
  {
394
394
  "max_seq_len": rbln_max_seq_len,
395
395
  "batch_size": rbln_batch_size,
396
- "prefill_chunk_size": prefill_chunk_size,
396
+ "prefill_chunk_size": rbln_prefill_chunk_size,
397
397
  "use_inputs_embeds": rbln_use_inputs_embeds,
398
398
  "kvcache_partition_len": rbln_kvcache_partition_len,
399
399
  "attn_impl": rbln_attn_impl,
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_dpt import RBLNDPTForDepthEstimation
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import logging
25
16
  from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import os
25
16
  from os import environ
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import TYPE_CHECKING
25
16
 
26
17
  import torch.nn as nn
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
 
25
16
  from transformers import AutoModelForCausalLM
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_gemma import RBLNGemmaForCausalLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import TYPE_CHECKING
25
16
 
26
17
  from ...models.decoderonly.decoderonly_architecture import (
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ....utils import logging
25
16
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
26
17
  from .gemma_architecture import GemmaWrapper
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_gpt2 import RBLNGPT2LMHeadModel