optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -58,23 +58,12 @@ class MidmLMHeadModelWrapper(torch.nn.Module):
         self.model = model.transformer
         self.lm_head = model.lm_head
         self.config = model.config
-        self.head_dim = self.config.n_embd // self.config.n_head
-        self.max_position_embeddings = (
-            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-        )
         self.max_seq_len = max_seq_len
-        self.rotary_dim = int(
-            model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
-        )
-        self.rotary_emb = self._init_rope()
 
-    def _init_rope(self):
-        """Initializes the Rotary Position Embeddings."""
-        rotary_emb = RotaryEmbedding(
-            self.rotary_dim,
-            max_position_embeddings=self.max_position_embeddings,
-        )
-        return rotary_emb
+        self.config.partial_rotary_factor = model.config.rotary_percentage
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.rope_theta = 10000
+        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
     def forward(
         self,
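The wrapper no longer builds its own `_init_rope`; it normalizes the Midm config (`partial_rotary_factor`, `head_dim`, `rope_theta`) and hands the whole config to the shared `RotaryEmbedding`, whose implementation moves into the new `optimum/rbln/transformers/modeling_rope_utils.py` (+283 lines in the file list above). The sketch below is a hypothetical illustration of what such a config-driven rotary embedding typically precomputes; the class name and internals are assumptions, not the library's actual code.

```python
import torch


# Hypothetical sketch, NOT the library's RotaryEmbedding: it only illustrates what a
# config-driven rotary embedding typically precomputes from the three config fields the
# wrapper now fills in (head_dim, rope_theta, partial_rotary_factor).
class ConfigDrivenRotaryEmbedding(torch.nn.Module):
    def __init__(self, config, max_seq_len_cached: int):
        super().__init__()
        # With partial_rotary_factor < 1.0 only a slice of each head is rotated.
        rotary_dim = int(config.head_dim * getattr(config, "partial_rotary_factor", 1.0))
        inv_freq = 1.0 / (
            config.rope_theta ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
        )
        positions = torch.arange(max_seq_len_cached, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # Static cos/sin tables for every position up to max_seq_len_cached.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, position_ids: torch.Tensor):
        return self.cos_cached[position_ids], self.sin_cached[position_ids]
```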

optimum/rbln/transformers/models/midm/modeling_midm.py
@@ -21,11 +21,7 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from ....modeling_config import RBLNConfig
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (
@@ -33,11 +29,7 @@ from .midm_architecture import (
 )
 
 
-logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
-    from transformers import (
-        PreTrainedModel,
-    )
+logger = logging.get_logger(__name__)
 
 
 class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -54,22 +46,10 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
 
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Midm.
+    _decoder_wrapper_cls = MidmLMHeadModelWrapper
+    _original_cls = MidmLMHeadModel
 
-        Returns:
-            Any: Midm's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MidmLMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        kwargs.setdefault("trust_remote_code", True)
+        return super().from_pretrained(*args, **kwargs)
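Because `from_pretrained` now defaults `trust_remote_code=True`, loading a Midm checkpoint (which ships custom modeling code) no longer requires passing the flag explicitly. A hedged usage sketch follows; the model id and the `export`/`rbln_*` keyword names follow common optimum-rbln conventions and are assumptions here, not taken from this diff.

```python
from optimum.rbln import RBLNMidmLMHeadModel

# trust_remote_code=True is injected by the override above, so it can be omitted.
# The model id and rbln_* keyword names below are illustrative assumptions.
model = RBLNMidmLMHeadModel.from_pretrained(
    "KT-AI/midm-bitext-S-7B-inst-v1",
    export=True,            # compile from the original HF checkpoint (assumed flag)
    rbln_max_seq_len=8192,  # assumed compile-time option
    rbln_batch_size=1,
)
```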

optimum/rbln/transformers/models/mistral/modeling_mistral.py
@@ -21,29 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import MistralForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .mistral_architecture import MistralForCausalLMWrapper
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -51,18 +40,4 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(MistralForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = MistralForCausalLMWrapper
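The Midm, Mistral, Phi, and Qwen2 classes all drop their hand-rolled `wrap_model_if_needed`/`__getattr__` plumbing in favor of a single `_decoder_wrapper_cls` class attribute that the shared base class consumes. The sketch below only illustrates that pattern; the base-class method name and helper types are assumptions, not the actual `RBLNDecoderOnlyModelForCausalLM` internals.

```python
import torch


class DummyWrapper(torch.nn.Module):
    """Placeholder standing in for e.g. MistralForCausalLMWrapper."""

    def __init__(self, model: torch.nn.Module, max_seq_len: int):
        super().__init__()
        self.model = model
        self.max_seq_len = max_seq_len


class DecoderOnlyBaseSketch:
    """Illustrative base: subclasses only declare which wrapper class to use."""

    _decoder_wrapper_cls = None

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config) -> torch.nn.Module:
        # Hypothetical hook: the base class builds the tracing wrapper declared by the
        # subclass, instead of every per-model file overriding this method.
        max_seq_len = rbln_config.model_cfg["max_seq_len"]
        return cls._decoder_wrapper_cls(model, max_seq_len).eval()


class MistralLikeModel(DecoderOnlyBaseSketch):
    # The whole per-model file shrinks to one declaration.
    _decoder_wrapper_cls = DummyWrapper
```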

optimum/rbln/transformers/models/phi/modeling_phi.py
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import PhiForCausalLM
-
-from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .phi_architecture import PhiWrapper
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Phi Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based PhiForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers PhiForCausalLM model into a RBLN transformer model by:
@@ -50,20 +40,4 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return PhiWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(PhiForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(
-            inspect.signature(val).parameters
-        ):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = PhiWrapper

optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -33,46 +33,12 @@ from transformers.modeling_outputs import (
 from ...cache_utils import RebelDynamicCache
 from ..decoderonly import (
     DecoderOnlyWrapper,
-    DynamicNTKScalingRotaryEmbedding,
-    LinearScalingRotaryEmbedding,
-    RotaryEmbedding,
     apply_rotary_pos_emb,
     slice_and_unsqueeze_cos_sin,
 )
 
 
 class PhiWrapper(DecoderOnlyWrapper):
-    def _init_rope(self):
-        if self.rope_scaling is None:
-            rotary_emb = RotaryEmbedding(
-                int(self.config.partial_rotary_factor * self.head_dim),
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
-            )
-        else:
-            scaling_type = self.rope_scaling["type"]
-            scaling_factor = self.rope_scaling["factor"]
-            if scaling_type == "linear":
-                rotary_emb = LinearScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            elif scaling_type == "dynamic":
-                rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
-        return rotary_emb
-
     def get_forward_dict(self):
         forward_dict = {}
         forward_dict.update(
@@ -86,6 +52,41 @@ class PhiWrapper(DecoderOnlyWrapper):
 
 
 class PhiAttention:
+    def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
+        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+        key_state = key_state.unsqueeze(2)
+        value_state = value_state.unsqueeze(2)
+        attn_mask = attn_mask.unsqueeze(2)
+
+        query_state = query_state.view(
+            1,
+            self.num_key_value_heads,
+            self.num_heads // self.num_key_value_heads,
+            -1,
+            self.head_dim,
+        )
+
+        key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
+
+        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
+        attn_weights = torch.matmul(
+            query_state.to(torch.float32),
+            key_state.to(torch.float32).transpose(3, 4),
+        ) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attn_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_state)
+
+        # reshape for removing repeat_kv
+        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+        return attn_output, key_state, value_state
+
     def forward(
         self,
         hidden_states: torch.Tensor,
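The new `_attn` helper factors out the single-sequence attention step that both the prefill and decode paths share: group query heads over the unrepeated KV heads, update the external KV cache, and upcast the QK product and softmax to fp32 before projecting back. Below is a self-contained sketch of that numeric pattern without the RBLN cache object; the shapes and the example sizes are assumptions for illustration.

```python
import math
import torch


def grouped_fp32_attention(query, key, value, attn_mask):
    """query: (1, num_heads, q_len, head_dim); key/value: (1, num_kv_heads, kv_len, head_dim).

    Groups query heads over the KV heads instead of repeat_kv, and upcasts the
    QK^T product and softmax to fp32 (the Phi-2 overflow workaround shown above).
    """
    _, num_heads, q_len, head_dim = query.shape
    num_kv_heads = key.shape[1]

    # (1, num_kv_heads, heads_per_group, q_len, head_dim)
    query = query.view(1, num_kv_heads, num_heads // num_kv_heads, q_len, head_dim)
    key = key.unsqueeze(2)    # (1, num_kv_heads, 1, kv_len, head_dim)
    value = value.unsqueeze(2)

    scores = torch.matmul(query.float(), key.float().transpose(3, 4)) / math.sqrt(head_dim)
    scores = scores + attn_mask  # additive mask, broadcast over the group dim
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

    out = torch.matmul(probs, value)            # (1, kv_heads, group, q_len, head_dim)
    out = out.view(1, num_heads, q_len, head_dim)
    return out.transpose(1, 2).reshape(1, q_len, num_heads * head_dim)


# Example shapes (assumed): 8 query heads grouped over 4 KV heads, 16 cached positions.
q = torch.randn(1, 8, 1, 64)
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)
mask = torch.zeros(1, 1, 1, 1, 16)
print(grouped_fp32_attention(q, k, v, mask).shape)  # torch.Size([1, 1, 512])
```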
@@ -95,7 +96,6 @@ class PhiAttention:
         output_attentions: bool = False,
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-        rotary_pos_emb=None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -108,24 +108,18 @@
         query_states = self.q_layernorm(query_states)
         key_states = self.k_layernorm(key_states)
 
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         # Partial rotary embedding
         query_rot, query_pass = (
-            query_states[..., : rotary_pos_emb.dim],
-            query_states[..., rotary_pos_emb.dim :],
+            query_states[..., : self.rotary_ndims],
+            query_states[..., self.rotary_ndims :],
         )
         key_rot, key_pass = (
-            key_states[..., : rotary_pos_emb.dim],
-            key_states[..., rotary_pos_emb.dim :],
+            key_states[..., : self.rotary_ndims],
+            key_states[..., self.rotary_ndims :],
         )
 
         # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
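The partial rotary split now reads the rotation width from `self.rotary_ndims` instead of a passed-in `rotary_pos_emb` object, but the pattern itself is unchanged: only the first `rotary_ndims` channels of each head are rotated, the rest pass through untouched. A standalone sketch of that pattern (tensor shapes and the example sizes are assumptions):

```python
import torch


def apply_partial_rotary(x, cos, sin, rotary_ndims):
    """x: (batch, heads, seq, head_dim); cos/sin: (batch, 1, seq, rotary_ndims).

    Rotate only the first rotary_ndims channels; concatenate the untouched tail back on.
    """
    x_rot, x_pass = x[..., :rotary_ndims], x[..., rotary_ndims:]

    # Standard rotate-half formulation applied to the rotated slice only.
    half = rotary_ndims // 2
    rotate_half = torch.cat((-x_rot[..., half:], x_rot[..., :half]), dim=-1)
    x_rot = x_rot * cos + rotate_half * sin

    return torch.cat((x_rot, x_pass), dim=-1)


# Example: head_dim=64 with partial_rotary_factor=0.5 -> rotary_ndims=32.
q = torch.randn(1, 8, 4, 64)
cos = torch.ones(1, 1, 4, 32)
sin = torch.zeros(1, 1, 4, 32)
print(apply_partial_rotary(q, cos, sin, 32).shape)  # torch.Size([1, 8, 4, 64])
```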
@@ -135,113 +129,38 @@
         query_states = torch.cat((query_rot, query_pass), dim=-1)
         key_states = torch.cat((key_rot, key_pass), dim=-1)
 
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
+        # Decoder (bsz > 1)
+        if bsz > 1:
+            iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
             for b in range(bsz):
-                query_state = query_states[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key_state = key_states[b].unsqueeze(0)
-                value_state = value_states[b].unsqueeze(0)
-
-                # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-                key_state = key_state.unsqueeze(2)
-                value_state = value_state.unsqueeze(2)
-                attn_mask = attn_mask.unsqueeze(2)
-
-                query_state = query_state.view(
-                    1,
-                    self.num_key_value_heads,
-                    self.num_heads // self.num_key_value_heads,
-                    q_len,
-                    self.head_dim,
+                attn_output, key_state, value_state = PhiAttention._attn(
+                    self,
+                    query_states[b].unsqueeze(0),
+                    key_states[b].unsqueeze(0),
+                    value_states[b].unsqueeze(0),
+                    attention_mask[b].unsqueeze(0),
+                    past_key_value,
+                    batch_idx=b,
+                    is_prefill=False,
                 )
-
-                key_state, value_state = past_key_value.update(
-                    key_state,
-                    value_state,
-                    self.layer_idx,
-                    b,
-                )
-
-                # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-                attn_weights = torch.matmul(
-                    query_state.to(torch.float32),
-                    key_state.to(torch.float32).transpose(3, 4),
-                ) / math.sqrt(self.head_dim)
-                attn_weights = attn_weights + attn_mask
-
-                # upcast attention to fp32
-                attn_weights = nn.functional.softmax(
-                    attn_weights, dim=-1, dtype=torch.float32
-                ).to(query_states.dtype)
-                attn_weights = nn.functional.dropout(
-                    attn_weights, p=self.attention_dropout, training=self.training
-                )
-                attn_output = torch.matmul(attn_weights, value_state)
-
-                # reshape for removing repeat_kv
-                attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2).contiguous()
-                attn_output = attn_output.reshape(
-                    1, q_len, self.num_heads * self.head_dim
-                )
-
-                all_key_states.append(key_state)
-                all_value_states.append(value_state)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
+                iterate_results["key_states"].append(key_state)
+                iterate_results["value_states"].append(value_state)
+                iterate_results["attn_output"].append(attn_output)
+
+            key_states = torch.cat(iterate_results["key_states"], dim=0)
+            value_states = torch.cat(iterate_results["value_states"], dim=0)
+            attn_output = torch.cat(iterate_results["attn_output"], dim=0)
+        # Prefill & Decoder (bsz == 1)
         else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            # reshape for removing repeat_kv
-            key_states = key_states.unsqueeze(2)
-            value_states = value_states.unsqueeze(2)
-            attention_mask = attention_mask.unsqueeze(2)
-            query_states = query_states.view(
-                1,
-                self.num_key_value_heads,
-                self.num_heads // self.num_key_value_heads,
-                q_len,
-                self.head_dim,
-            )
-
-            key_states, value_states = past_key_value.update(
+            attn_output, key_states, value_states = PhiAttention._attn(
+                self,
+                query_states,
                 key_states,
                 value_states,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-            attn_weights = torch.matmul(
-                query_states.to(torch.float32),
-                key_states.to(torch.float32).transpose(3, 4),
-            ) / math.sqrt(self.head_dim)
-            attn_weights = attn_weights + attention_mask
-
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(
-                attn_weights, dim=-1, dtype=torch.float32
-            ).to(value_states.dtype)
-            attn_weights = torch.nn.functional.dropout(
-                attn_weights, p=self.attention_dropout, training=self.training
-            )
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            # reshape for removing repeat_kv
-            attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2).contiguous()
-            attn_output = attn_output.reshape(
-                bsz, q_len, self.num_heads * self.head_dim
+                attention_mask,
+                past_key_value,
+                batch_idx=batch_index,
+                is_prefill=True,
             )
 
         attn_output = self.dense(attn_output)
@@ -265,12 +184,9 @@ class PhiDecoderLayer:
         batch_ids: Optional[torch.LongTensor] = None,
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-        rotary_pos_emb=None,
         forward_dict: Optional[Dict[str, classmethod]] = None,
         **kwargs,
-    ) -> Tuple[
-        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
-    ]:
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`):
@@ -294,9 +210,7 @@ class PhiDecoderLayer:
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        attn_outputs, self_attn_weights, key_states, value_states = forward_dict[
-            "decoder_layer"
-        ](
+        attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
             self.self_attn,
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -307,7 +221,6 @@ class PhiDecoderLayer:
            use_cache=use_cache,
            cos=cos,
            sin=sin,
-           rotary_pos_emb=rotary_pos_emb,
            **kwargs,
        )
        past_key_value.assign(key_states, value_states, layer_idx)
@@ -339,6 +252,8 @@ class PhiModel:
         use_cache: Optional[bool] = True,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
+        cache_pos_for_partitions: Optional[torch.Tensor] = None,
+        kvcache_partition_size: Optional[torch.Tensor] = None,
         forward_dict: Optional[Dict[str, classmethod]] = None,
         rotary_pos_emb=None,
     ) -> BaseModelOutputWithPast:
@@ -378,7 +293,8 @@ class PhiModel:
                 batch_ids=batch_ids,
                 cos=cos,
                 sin=sin,
-                rotary_pos_emb=rotary_pos_emb,
+                cache_pos_for_partitions=cache_pos_for_partitions,
+                kvcache_partition_size=kvcache_partition_size,
                 forward_dict=forward_dict,
             )
 

optimum/rbln/transformers/models/qwen2/__init__.py (new file)
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_qwen2 import RBLNQwen2ForCausalLM

optimum/rbln/transformers/models/qwen2/modeling_qwen2.py (new file)
@@ -0,0 +1,43 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .qwen2_architecture import QWEN2Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Llama Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    """
+
+    _decoder_wrapper_cls = QWEN2Wrapper

optimum/rbln/transformers/models/qwen2/qwen2_architecture.py (new file)
@@ -0,0 +1,29 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class QWEN2Wrapper(DecoderOnlyWrapper):
+    pass
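Qwen2 support is essentially the generic decoder-only path with a pass-through wrapper, so it should be usable like the other causal-LM classes. A hedged usage sketch follows; the model id and the `export`/`rbln_*` keyword names follow common optimum-rbln usage and are assumptions here, not taken from this diff.

```python
from optimum.rbln import RBLNQwen2ForCausalLM
from transformers import AutoTokenizer

# Compile a Qwen2 checkpoint for RBLN NPUs, then generate.
# The model id and rbln_* keyword names below are illustrative assumptions.
model = RBLNQwen2ForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    export=True,            # trace + compile from the original HF weights (assumed flag)
    rbln_max_seq_len=8192,  # assumed compile-time option
    rbln_batch_size=1,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")

inputs = tokenizer("What does a partial rotary embedding do?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```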

optimum/rbln/transformers/models/seq2seq/__init__.py (new file)
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_seq2seq import RBLNModelForSeq2SeqLM