optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +5 -1
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
- optimum/rbln/diffusers/models/controlnet.py +36 -56
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
- optimum/rbln/modeling_base.py +12 -5
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +7 -0
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +2 -2
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/modeling_rope_utils.py (new file)
@@ -0,0 +1,283 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from transformers import PretrainedConfig
+
+
+def _compute_default_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies according to the original RoPE implementation
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+    return inv_freq, attention_factor
+
+
+def _compute_linear_scaling_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    factor = config.rope_scaling["factor"]
+
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    # Then applies linear scaling to the frequencies.
+    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+    # applying scaling to the inverse frequencies is equivalent.
+    inv_freq /= factor
+    return inv_freq, attention_factor
+
+
+def _compute_dynamic_ntk_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length, used to update the dynamic RoPE at inference time.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Process with chunk_size to reduce precesion error
+    chunk_size = 4096
+    chunks = (seq_len + chunk_size - 1) // chunk_size
+
+    inv_freq_list = []
+    for i in range(chunks):
+        start = i * chunk_size
+        end = min((i + 1) * chunk_size, seq_len)
+
+        seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
+        seq_lens = torch.where(seq_lens > max_position_embeddings, seq_lens, max_position_embeddings)
+
+        # Compute the inverse frequencies for each chunk
+        scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+        inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+
+        inv_freq_list.append(inv_freq)
+
+    final_inv_freq = torch.cat(inv_freq_list, dim=0)
+
+    return final_inv_freq, attention_factor
+
+
+def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Please refer to the
+    [original paper](https://arxiv.org/abs/2309.00071)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    # Sets the attention factor as suggested in the paper
+    attention_factor = config.rope_scaling.get("attention_factor")
+    if attention_factor is None:
+        attention_factor = 0.1 * math.log(factor) + 1.0
+
+    # Optional config options
+    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+    beta_fast = config.rope_scaling.get("beta_fast") or 32
+    beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+    # Compute the inverse frequencies
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+        """Find dimension range bounds based on rotations"""
+        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
+    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+    inv_freq_extrapolation = 1.0 / pos_freqs
+    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+
+    # Get n-dimensional rotational scaling corrected for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
+
+    return inv_freq, attention_factor
+
+
+def _compute_longrope_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+    [original implementation](https://github.com/microsoft/LongRoPE)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    long_factor = config.rope_scaling["long_factor"]
+    short_factor = config.rope_scaling["short_factor"]
+    factor = config.rope_scaling.get("factor")
+    attention_factor = config.rope_scaling.get("attention_factor")
+
+    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if hasattr(config, "original_max_position_embeddings"):
+        max_position_embeddings = config.original_max_position_embeddings
+        expanded_max_position_embeddings = config.max_position_embeddings
+        factor = expanded_max_position_embeddings / max_position_embeddings
+    else:
+        max_position_embeddings = config.max_position_embeddings
+        expanded_max_position_embeddings = max_position_embeddings * factor
+
+    # Sets the attention factor as suggested in the paper
+    if attention_factor is None:
+        if factor <= 1.0:
+            attention_factor = 1.0
+        else:
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+
+    # Compute the inverse frequencies -- scaled based on the target sequence length
+    if expanded_max_position_embeddings > max_position_embeddings:
+        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+    else:
+        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
+    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+    return inv_freq, attention_factor
+
+
+def _compute_llama3_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor
+
+
+# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+# parameterizations, as long as the callable has the same signature.
+ROPE_INIT_FUNCTIONS = {
+    "default": _compute_default_rope_parameters,
+    "linear": _compute_linear_scaling_rope_parameters,
+    "dynamic": _compute_dynamic_ntk_parameters,
+    "yarn": _compute_yarn_parameters,
+    "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
+}
optimum/rbln/transformers/models/__init__.py
@@ -21,35 +21,84 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+from typing import TYPE_CHECKING

-from .
-... (30 more removed lines, content not captured in this diff view)
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "auto": [
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+    ],
+    "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
+    "bert": ["RBLNBertModel"],
+    "clip": ["RBLNCLIPTextModel", "RBLNCLIPTextModelWithProjection", "RBLNCLIPVisionModel"],
+    "dpt": ["RBLNDPTForDepthEstimation"],
+    "exaone": ["RBLNExaoneForCausalLM"],
+    "gemma": ["RBLNGemmaForCausalLM"],
+    "gpt2": ["RBLNGPT2LMHeadModel"],
+    "llama": ["RBLNLlamaForCausalLM"],
+    "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
+    "midm": ["RBLNMidmLMHeadModel"],
+    "mistral": ["RBLNMistralForCausalLM"],
+    "phi": ["RBLNPhiForCausalLM"],
+    "qwen2": ["RBLNQwen2ForCausalLM"],
+    "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
+    "wav2vec2": ["RBLNWav2Vec2ForCTC"],
+    "whisper": ["RBLNWhisperForConditionalGeneration"],
+    "xlm_roberta": ["RBLNXLMRobertaModel"],
+}
+
+if TYPE_CHECKING:
+    from .auto import (
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+    )
+    from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
+    from .bert import RBLNBertModel
+    from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+    from .dpt import RBLNDPTForDepthEstimation
+    from .exaone import RBLNExaoneForCausalLM
+    from .gemma import RBLNGemmaForCausalLM
+    from .gpt2 import RBLNGPT2LMHeadModel
+    from .llama import RBLNLlamaForCausalLM
+    from .llava_next import RBLNLlavaNextForConditionalGeneration
+    from .midm import RBLNMidmLMHeadModel
+    from .mistral import RBLNMistralForCausalLM
+    from .phi import RBLNPhiForCausalLM
+    from .qwen2 import RBLNQwen2ForCausalLM
+    from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+    from .wav2vec2 import RBLNWav2Vec2ForCTC
+    from .whisper import RBLNWhisperForConditionalGeneration
+    from .xlm_roberta import RBLNXLMRobertaModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -26,8 +26,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

 import torch
 from transformers import (
-    AutoConfig,
-    AutoModel,
     CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
@@ -39,6 +37,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput

 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.context import override_auto_classes


 logger = logging.getLogger(__name__)
@@ -58,18 +57,14 @@ class _TextEncoder(torch.nn.Module):


 class RBLNCLIPTextModel(RBLNModel):
-    original_model_class = CLIPTextModel
-    original_config_class = CLIPTextConfig
-
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        ... (6 removed lines, content not captured in this diff view)
-        AutoModel.from_pretrained = modeltmp
+        with override_auto_classes(
+            config_func=CLIPTextConfig.from_pretrained,
+            model_func=CLIPTextModel.from_pretrained,
+            skip_taskmanager=False,
+        ):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt

     @classmethod
@@ -133,18 +128,14 @@ class _VisionEncoder(torch.nn.Module):


 class RBLNCLIPVisionModel(RBLNModel):
-    original_model_class = CLIPVisionModel
-    original_config_class = CLIPVisionConfig
-
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        ... (6 removed lines, content not captured in this diff view)
-        AutoModel.from_pretrained = modeltmp
+        with override_auto_classes(
+            config_func=CLIPVisionConfig.from_pretrained,
+            model_func=CLIPVisionModel.from_pretrained,
+            skip_taskmanager=False,
+        ):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt

     @classmethod