optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +22 -12
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
- optimum/rbln/diffusers/pipelines/__init__.py +22 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +164 -758
- optimum/rbln/modeling_diffusers.py +51 -122
- optimum/rbln/transformers/__init__.py +0 -2
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
- optimum/rbln/utils/decorator_utils.py +10 -6
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +15 -1
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +1 -1
- optimum/rbln/utils/submodule.py +114 -0
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
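
Beyond the Midm and decoder-only refactor shown in the hunks below, the added `stable_diffusion_3` pipeline modules and `transformer_sd3.py` indicate Stable Diffusion 3 support in 0.1.15. As a rough orientation only, a usage sketch follows; the `RBLNStableDiffusion3Pipeline` class name and the `export=True` keyword are assumptions inferred from the new file paths and the existing optimum-rbln pipeline convention, not confirmed by this diff.

```python
# Hypothetical sketch -- class name and keyword arguments are inferred from the new
# module paths and the usual optimum-rbln calling convention, not confirmed by this diff.
from optimum.rbln import RBLNStableDiffusion3Pipeline  # assumed top-level export

# Compile the pipeline for RBLN NPUs on first load, then reuse the saved artifacts.
pipe = RBLNStableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # example checkpoint id
    export=True,  # assumed optimum-style "compile from the original checkpoint" flag
)
pipe.save_pretrained("sd3-rbln")  # persist the compiled pipeline for later runs
image = pipe(prompt="a watercolor fox").images[0]
```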
optimum/rbln/transformers/models/midm/midm_architecture.py (0.1.13 → 0.1.15)

@@ -21,18 +21,24 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
+from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast

-from ....transformers.models.decoderonly.decoderonly_architecture import
-
-
-
+from ....transformers.models.decoderonly.decoderonly_architecture import rotate_half
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb_partial,
 )
-
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as MidmLMHeadModel


 def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
@@ -50,253 +56,93 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed


-class MidmLMHeadModelWrapper(
-
-
-    def __init__(self, model, max_seq_len):
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-
-        self.config.partial_rotary_factor = model.config.rotary_percentage
-        self.config.head_dim = self.config.n_embd // self.config.n_head
+class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+    def get_rotary_emb(self, max_seq_len):
         self.config.rope_theta = 10000
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
-
-
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            rotary_pos_emb=self.rotary_emb,
-            batch_ids=rbln_batch_position,
-        )
-
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-
-        logits = self.lm_head(hidden_states)
-        output = (logits,) + outputs[1:]
-
-        return output, batch_position + query_idx
-
-
-def layernorm1p(module, input):
-    """Applies Layer Normalization with a slight modification on the weights."""
-    return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
-
-
-class _MidmAttention:
-    """Custom implementation of the MidmAttention class with specific modifications."""
-
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        """Computes the attention weights and output."""
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.partial_rotary_factor = self.config.rotary_percentage
+        return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = MidmAttention(layer.attn)
+            new_layer = MidmLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class MidmModel(DecoderOnlyModel):
+    mask_fmin = -10000.0
+
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
             )

-
-        attn_weights = attn_weights / float(self.layer_idx + 1)
+        return layernorm1p

-
-
-
-
-
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_weights = attn_weights.type(value.dtype)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        """Defines the forward pass for the attention mechanism."""
-        bsz, q_len, _ = hidden_states.size()
-
-        querys, keys, values = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(querys, self.num_heads, self.head_dim).contiguous()
-        keys = self._split_heads(keys, self.num_heads, self.head_dim).contiguous()
-        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
-
-        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_f)
+        else:
+            return self._original_mod.ln_f

-
-
-            attn_mask = attention_mask[b].unsqueeze(0)
-            key = keys[b].unsqueeze(0)
-            value = values[b].unsqueeze(0)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte

-
-
-                value,
-                self.layer_idx,
-                b,
-            )
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe

-            attn_output, _ = _MidmAttention._attn(self, query, key, value, attn_mask)
-            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

-
-
-
+class MidmLayer(DecoderOnlyLayer):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )

-
-            values = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return layernorm1p

+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_1)
         else:
-
-            batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
+            return self._original_mod.ln_1

-
-
-
-        attn_output = self.c_proj(attn_output)
-        return attn_output, keys, values
-
-
-class _MidmBlock:
-    """Custom implementation of the MidmBlock class with specific modifications."""
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-        """Defines the forward pass for the block."""
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_1, hidden_states)
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_2)
         else:
-
-
-
-
-
-
-
-
-
-
+            return self._original_mod.ln_2
+
+
+class MidmAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
+
+    def rbln_attention(self, *args, **kwargs):
+        return super().rbln_attention(
+            *args,
+            **kwargs,
+            layer_idx=self.layer_idx,
+            scale_attn_weights=self._original_mod.scale_attn_weights,
+            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
         )
-        past_key_value.assign(k, v, layer_idx)
-
-        hidden_states = hidden_states + residual

-
-
-            hidden_states = layernorm1p(self.ln_2, hidden_states)
-        else:
-            hidden_states = self.ln_2(hidden_states)
-
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _MidmModel:
-    """Custom implementation of the MidmModel class with specific modifications."""
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache_4D] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        rotary_pos_emb=None,
-        batch_ids: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """Defines the forward pass for the model."""
-        input_shape = input_ids.size()
-
-        attention_mask = (1.0 - attention_mask) * -10000.0
-
-        inputs_embeds = self.wte(input_ids)
-
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-        hidden_states = inputs_embeds
-
-        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
-            hidden_states, updated_cache = _MidmBlock.forward(
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_values,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-            )
-
-        hidden_states = layernorm1p(self.ln_f, hidden_states)
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
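
The hunks above replace the monkey-patched `_MidmAttention`/`_MidmBlock`/`_MidmModel` forward passes with thin subclasses of the shared decoder-only base classes, which drive the forward pass themselves and expose small hook methods (`get_last_layernorm`, `get_pre_attention_layernorm`, `projection`, ...) that each architecture overrides. A minimal, self-contained sketch of that pattern follows; the classes below are simplified stand-ins, not the actual `DecoderOnlyModel` implementation.

```python
import torch
import torch.nn as nn


class DecoderOnlyModelSketch(nn.Module):
    """Simplified stand-in for the shared base class: it only calls hooks."""

    def __init__(self, original_mod: nn.Module):
        super().__init__()
        self._original_mod = original_mod

    def get_last_layernorm(self):
        # Default hook: most architectures expose a plain final LayerNorm as `ln_f`.
        return self._original_mod.ln_f

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The base class owns the computation; subclasses only swap the pieces.
        return self.get_last_layernorm()(hidden_states)


class MidmModelSketch(DecoderOnlyModelSketch):
    """Mi:dm override: apply LayerNorm with weights shifted by +1 ("layernorm1p")."""

    def get_last_layernorm(self):
        ln = self._original_mod.ln_f

        def layernorm1p(x: torch.Tensor) -> torch.Tensor:
            return nn.functional.layer_norm(x, ln.normalized_shape, ln.weight + 1, ln.bias, ln.eps)

        return layernorm1p


# Toy usage: any module exposing an `ln_f` attribute can be wrapped.
toy = nn.Module()
toy.ln_f = nn.LayerNorm(8)
out = MidmModelSketch(toy)(torch.randn(2, 4, 8))
```

The Mi:dm-specific layernorm1p tweak lives entirely in an overridden hook, which is the same shape as the `MidmModel.get_last_layernorm` override in the hunk above.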
optimum/rbln/transformers/models/midm/modeling_midm.py (0.1.13 → 0.1.15)

@@ -21,12 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+
+from transformers import AutoModelForCausalLM
+
 from ....utils import logging
-from
-from .
-from .midm_architecture import (
-    MidmLMHeadModelWrapper,
-)
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .midm_architecture import MidmLMHeadModelWrapper


 logger = logging.get_logger(__name__)
@@ -47,7 +47,7 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = MidmLMHeadModelWrapper
-
+    _hf_class = AutoModelForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
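
With `_hf_class = AutoModelForCausalLM` set and the `hf_hub_cached` copies of the Mi:dm modeling code removed, the original checkpoint is presumably loaded through the standard `transformers` auto class with remote code. A hedged usage sketch follows; the checkpoint id and the `export`/`trust_remote_code` keywords follow common optimum-rbln and transformers conventions and are assumptions, not confirmed by this diff.

```python
# Hypothetical usage sketch -- argument names follow common optimum-rbln conventions and may differ.
from optimum.rbln import RBLNMidmLMHeadModel  # class shown in this diff; top-level export assumed

model = RBLNMidmLMHeadModel.from_pretrained(
    "KT-AI/midm-bitext-S-7B-inst-v1",  # example Mi:dm checkpoint id (assumption)
    export=True,             # assumed flag: compile from the original HF checkpoint
    trust_remote_code=True,  # Mi:dm ships custom modeling code on the Hub
)
model.save_pretrained("midm-rbln")  # persist the compiled model for reuse
```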