optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
- optimum/rbln/diffusers/models/controlnet.py +36 -62
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +117 -144
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -28
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
- optimum_rbln-0.1.13.dist-info/RECORD +107 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/RECORD +0 -93
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
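One change worth calling out from the listing above is the relocation of the Seq2Seq base class (`optimum/rbln/modeling_seq2seq.py` → `optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py`). A minimal sketch of the new import path, inferred from the new `from ...models.seq2seq import RBLNModelForSeq2SeqLM` line visible in the Bart diff further down; whether the old `optimum.rbln.modeling_seq2seq` path is still aliased in 0.1.13 is not shown in this diff:

```python
# New location of the Seq2Seq base class in 0.1.13 (path taken from the file move above
# and the new intra-package import in modeling_bart.py).
from optimum.rbln.transformers.models.seq2seq import RBLNModelForSeq2SeqLM

# 0.1.11-style import affected by the move (shown only for comparison; it presumably
# exported the same class, but that is not visible in this diff):
# from optimum.rbln.modeling_seq2seq import RBLNModelForSeq2SeqLM
```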
optimum/rbln/transformers/modeling_rope_utils.py (new file, +283 lines)
@@ -0,0 +1,283 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from transformers import PretrainedConfig
+
+
+def _compute_default_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies according to the original RoPE implementation
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+    return inv_freq, attention_factor
+
+
+def _compute_linear_scaling_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    factor = config.rope_scaling["factor"]
+
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    # Then applies linear scaling to the frequencies.
+    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+    # applying scaling to the inverse frequencies is equivalent.
+    inv_freq /= factor
+    return inv_freq, attention_factor
+
+
+def _compute_dynamic_ntk_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length, used to update the dynamic RoPE at inference time.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Process with chunk_size to reduce precesion error
+    chunk_size = 4096
+    chunks = (seq_len + chunk_size - 1) // chunk_size
+
+    inv_freq_list = []
+    for i in range(chunks):
+        start = i * chunk_size
+        end = min((i + 1) * chunk_size, seq_len)
+
+        seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
+        seq_lens = torch.where(seq_lens > max_position_embeddings, seq_lens, max_position_embeddings)
+
+        # Compute the inverse frequencies for each chunk
+        scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+        inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+
+        inv_freq_list.append(inv_freq)
+
+    final_inv_freq = torch.cat(inv_freq_list, dim=0)
+
+    return final_inv_freq, attention_factor
+
+
+def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Please refer to the
+    [original paper](https://arxiv.org/abs/2309.00071)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    # Sets the attention factor as suggested in the paper
+    attention_factor = config.rope_scaling.get("attention_factor")
+    if attention_factor is None:
+        attention_factor = 0.1 * math.log(factor) + 1.0
+
+    # Optional config options
+    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+    beta_fast = config.rope_scaling.get("beta_fast") or 32
+    beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+    # Compute the inverse frequencies
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+        """Find dimension range bounds based on rotations"""
+        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
+    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+    inv_freq_extrapolation = 1.0 / pos_freqs
+    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+
+    # Get n-dimensional rotational scaling corrected for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
+
+    return inv_freq, attention_factor
+
+
+def _compute_longrope_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+    [original implementation](https://github.com/microsoft/LongRoPE)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    long_factor = config.rope_scaling["long_factor"]
+    short_factor = config.rope_scaling["short_factor"]
+    factor = config.rope_scaling.get("factor")
+    attention_factor = config.rope_scaling.get("attention_factor")
+
+    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if hasattr(config, "original_max_position_embeddings"):
+        max_position_embeddings = config.original_max_position_embeddings
+        expanded_max_position_embeddings = config.max_position_embeddings
+        factor = expanded_max_position_embeddings / max_position_embeddings
+    else:
+        max_position_embeddings = config.max_position_embeddings
+        expanded_max_position_embeddings = max_position_embeddings * factor
+
+    # Sets the attention factor as suggested in the paper
+    if attention_factor is None:
+        if factor <= 1.0:
+            attention_factor = 1.0
+        else:
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+
+    # Compute the inverse frequencies -- scaled based on the target sequence length
+    if expanded_max_position_embeddings > max_position_embeddings:
+        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+    else:
+        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
+    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+    return inv_freq, attention_factor
+
+
+def _compute_llama3_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor
+
+
+# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+# parameterizations, as long as the callable has the same signature.
+ROPE_INIT_FUNCTIONS = {
+    "default": _compute_default_rope_parameters,
+    "linear": _compute_linear_scaling_rope_parameters,
+    "dynamic": _compute_dynamic_ntk_parameters,
+    "yarn": _compute_yarn_parameters,
+    "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
+}
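The new `modeling_rope_utils.py` above is a trimmed-down port of the RoPE parameter helpers in `transformers`: each `ROPE_INIT_FUNCTIONS` entry maps a `rope_type` string to a callable with the signature `(config, seq_len) -> (inv_freq, attention_factor)`. As a reading aid, here is a minimal sketch of how such a registry is typically consumed; the config values are illustrative assumptions, not values taken from this package, and optimum-rbln wires this up internally rather than expecting users to call it directly.

```python
# Sketch only: build cos/sin tables from the registry above. The config values are
# illustrative (head_dim = 4096 / 32 = 128); optimum-rbln does this internally.
import torch
from transformers import PretrainedConfig

from optimum.rbln.transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = PretrainedConfig(
    hidden_size=4096,
    num_attention_heads=32,
    rope_theta=10000.0,
    max_position_embeddings=8192,
    rope_scaling={"rope_type": "linear", "factor": 2.0},
)

rope_type = config.rope_scaling.get("rope_type", "default")
inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, seq_len=8192)

positions = torch.arange(8192, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)   # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)    # [seq_len, head_dim]
cos, sin = emb.cos() * attention_scaling, emb.sin() * attention_scaling
```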
optimum/rbln/transformers/models/__init__.py (+80 -28)
@@ -21,32 +21,84 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+from typing import TYPE_CHECKING

-from .
- [old lines 26-52 also removed; their content is truncated in this rendering]
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "auto": [
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+    ],
+    "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
+    "bert": ["RBLNBertModel"],
+    "clip": ["RBLNCLIPTextModel", "RBLNCLIPTextModelWithProjection", "RBLNCLIPVisionModel"],
+    "dpt": ["RBLNDPTForDepthEstimation"],
+    "exaone": ["RBLNExaoneForCausalLM"],
+    "gemma": ["RBLNGemmaForCausalLM"],
+    "gpt2": ["RBLNGPT2LMHeadModel"],
+    "llama": ["RBLNLlamaForCausalLM"],
+    "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
+    "midm": ["RBLNMidmLMHeadModel"],
+    "mistral": ["RBLNMistralForCausalLM"],
+    "phi": ["RBLNPhiForCausalLM"],
+    "qwen2": ["RBLNQwen2ForCausalLM"],
+    "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
+    "wav2vec2": ["RBLNWav2Vec2ForCTC"],
+    "whisper": ["RBLNWhisperForConditionalGeneration"],
+    "xlm_roberta": ["RBLNXLMRobertaModel"],
+}
+
+if TYPE_CHECKING:
+    from .auto import (
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+    )
+    from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
+    from .bert import RBLNBertModel
+    from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+    from .dpt import RBLNDPTForDepthEstimation
+    from .exaone import RBLNExaoneForCausalLM
+    from .gemma import RBLNGemmaForCausalLM
+    from .gpt2 import RBLNGPT2LMHeadModel
+    from .llama import RBLNLlamaForCausalLM
+    from .llava_next import RBLNLlavaNextForConditionalGeneration
+    from .midm import RBLNMidmLMHeadModel
+    from .mistral import RBLNMistralForCausalLM
+    from .phi import RBLNPhiForCausalLM
+    from .qwen2 import RBLNQwen2ForCausalLM
+    from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+    from .wav2vec2 import RBLNWav2Vec2ForCTC
+    from .whisper import RBLNWhisperForConditionalGeneration
+    from .xlm_roberta import RBLNXLMRobertaModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
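The rewritten `models/__init__.py` swaps eager imports for the `_LazyModule` pattern used throughout `transformers`: `_import_structure` only names the submodules and their public classes, the `TYPE_CHECKING` branch keeps static analyzers happy, and the real import is deferred until an attribute is first accessed. A small illustration of the resulting behaviour, assuming optimum-rbln 0.1.13 and its runtime dependencies are installed:

```python
# Both spellings resolve to the same class object; the lazy variant only triggers
# the heavy model-file import when the attribute is first accessed.
from optimum.rbln.transformers.models import RBLNLlamaForCausalLM                  # lazy
from optimum.rbln.transformers.models.llama import RBLNLlamaForCausalLM as Direct  # eager

assert RBLNLlamaForCausalLM is Direct
```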
optimum/rbln/transformers/models/bart/bart_architecture.py (+18 -12)
@@ -47,6 +47,12 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)


+class BartWrapper:
+    def __init__(self, model):
+        self.encoder = BartEncoderWrapper(model)
+        self.decoder = BartDecoderWrapper(model)
+
+
 class _BartAttention(BartAttention):
     def forward(
         self,
@@ -238,6 +244,7 @@ class _BartSdpaAttention(BartSdpaAttention):
             value_states, dim=2, start=cache_position, end=cache_position + 1
         )

+        # need 4d shape (input tensors) for scaled_dot_product_attention
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -324,7 +331,6 @@ class _BartDecoder(BartDecoder):
         attn_impl: str = "eager",
     ):
         # embedding
-        # thkim fix : transformers == 4.44.2 compile
         if hasattr(self, "embed_scale"):
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
         else:
@@ -336,13 +342,15 @@ class _BartDecoder(BartDecoder):
             hidden_states = inputs_embeds + positions
         else:
             hidden_all = []
+            # compiler pattern base dependency -> take + add
             for i in range(input_ids.shape[0]):
                 # cache position [N,1]
                 positions_idx = cache_position[i]
+                # offset is set 2 in bart embedding
                 position_weight = self.embed_positions.weight[2:]
                 position = position_weight[positions_idx]
-
-                hidden_all.append(
+                batch_hidden = position + inputs_embeds[i]
+                hidden_all.append(batch_hidden)
             hidden_states = torch.stack(hidden_all, dim=0)

         hidden_states = self.layernorm_embedding(hidden_states)
@@ -444,6 +452,7 @@ class BartDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)

+        # return batch_position to keep it as a variable within the graph
         return lm_logits, self_kv_cache, batch_position


@@ -467,9 +476,6 @@ class BartEncoderWrapper(torch.nn.Module):
         cross_key_value: torch.Tensor = None,
         batch_idx: torch.Tensor = None,
     ) -> Tuple[torch.Tensor]:
-        encoder_batch_size = input_ids.shape[0]
-        decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search
-
         # 1. run encoder
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
         last_hidden_states = encoder_outputs[0]
@@ -477,19 +483,19 @@ class BartEncoderWrapper(torch.nn.Module):
         # 2. run dummy decoder to get pre-calculated cross-key_values for generation
         dummy_past_key_value = []
         for _ in range(self.num_layers):
-            pkv_self_attn_key = torch.zeros(
-            pkv_self_attn_value = torch.zeros(
-            pkv_cross_attn_key = torch.zeros(
-            pkv_cross_attn_value = torch.zeros(
+            pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)

-        decoder_attention_mask = torch.zeros(
+        decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1

         decoder_outputs = _BartDecoder.forward(
             self.decoder,
-            input_ids=torch.zeros((
+            input_ids=torch.zeros((1, 1), dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=torch.tensor(0, dtype=torch.int32),
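The new `BartWrapper` is a thin container that pairs the encoder and decoder export wrappers for a single BART checkpoint; `RBLNBartForConditionalGeneration.wrap_model_if_needed` (see the `modeling_bart.py` hunks below) returns it so that compilation sees one object with an `.encoder` and a `.decoder` graph. A minimal sketch of that wrapping step in isolation; the checkpoint name is only an example, and the actual tracing/compilation is driven by `RBLNModelForSeq2SeqLM` rather than invoked manually like this:

```python
# Illustration only: wrap a BART checkpoint the way wrap_model_if_needed() does.
from transformers import BartForConditionalGeneration

from optimum.rbln.transformers.models.bart.bart_architecture import BartWrapper

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
wrapped = BartWrapper(model)
# wrapped.encoder -> BartEncoderWrapper (runs the encoder plus a dummy decoder pass
#                    to pre-compute cross-attention key/values with static shapes)
# wrapped.decoder -> BartDecoderWrapper (single-step decoding over the static KV cache)
```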
optimum/rbln/transformers/models/bart/modeling_bart.py (+25 -6)
@@ -22,23 +22,25 @@
 # from Rebellions Inc.

 import inspect
-import
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

-from transformers import
+from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig

 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .bart_architecture import BartWrapper


-logger =
+logger = get_logger()
+

 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel


 class RBLNBartModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = BartModel
     original_config_class = BartConfig

@@ -104,3 +106,20 @@ class RBLNBartModel(RBLNModel):

         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
         return rbln_config
+
+
+class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return BartWrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(BartForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
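With `RBLNModelForSeq2SeqLM` as its base, BART joins the architectures that can be exported through the usual optimum-rbln entry point; the `__getattr__` redirect simply lets the RBLN class borrow `BartForConditionalGeneration`'s instance methods (generation helpers, for example) without inheriting from it. A hedged usage sketch, assuming the class is re-exported at the package root like the other RBLN model classes and that the usual `export=True` convention applies in 0.1.13:

```python
# Usage sketch (not taken from the diff): compile a BART checkpoint for RBLN NPUs,
# save the compiled artifacts, and reload them later without re-exporting.
from optimum.rbln import RBLNBartForConditionalGeneration

model = RBLNBartForConditionalGeneration.from_pretrained("facebook/bart-base", export=True)
model.save_pretrained("bart-base-rbln")

reloaded = RBLNBartForConditionalGeneration.from_pretrained("bart-base-rbln", export=False)
```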
optimum/rbln/transformers/models/bert/modeling_bert.py (+1 -2)
@@ -25,7 +25,7 @@ import inspect
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union

-from transformers import
+from transformers import BertConfig, BertModel, PretrainedConfig

 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -38,7 +38,6 @@ if TYPE_CHECKING:


 class RBLNBertModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = BertModel
     original_config_class = BertConfig

optimum/rbln/transformers/models/clip/modeling_clip.py (+13 -23)
@@ -26,8 +26,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

 import torch
 from transformers import (
-    AutoConfig,
-    AutoModel,
     CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
@@ -39,6 +37,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput

 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.context import override_auto_classes


 logger = logging.getLogger(__name__)
@@ -58,19 +57,14 @@ class _TextEncoder(torch.nn.Module):


 class RBLNCLIPTextModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
-    original_model_class = CLIPTextModel
-    original_config_class = CLIPTextConfig
-
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
- [old lines 67-72 also removed; their content is truncated in this rendering]
-        AutoModel.from_pretrained = modeltmp
+        with override_auto_classes(
+            config_func=CLIPTextConfig.from_pretrained,
+            model_func=CLIPTextModel.from_pretrained,
+            skip_taskmanager=False,
+        ):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt

     @classmethod
@@ -134,18 +128,14 @@ class _VisionEncoder(torch.nn.Module):


 class RBLNCLIPVisionModel(RBLNModel):
-    original_model_class = CLIPVisionModel
-    original_config_class = CLIPVisionConfig
-
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
- [old lines 142-147 also removed; their content is truncated in this rendering]
-        AutoModel.from_pretrained = modeltmp
+        with override_auto_classes(
+            config_func=CLIPVisionConfig.from_pretrained,
+            model_func=CLIPVisionModel.from_pretrained,
+            skip_taskmanager=False,
+        ):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt

     @classmethod
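Both CLIP classes now delegate the `AutoModel`/`AutoConfig` monkey-patching that 0.1.11 did inline to a reusable `override_auto_classes` context manager from the new `optimum/rbln/utils/context.py` (+58 lines, not shown in this diff). Its real body is therefore not visible here; the sketch below is only an illustrative guess at what a context manager with this call signature could look like, based on the patching pattern the old inline code used (`AutoModel.from_pretrained = ...`). The `skip_taskmanager` flag is left unused because its behaviour cannot be inferred from the hunks above.

```python
# Illustrative guess, NOT the package's implementation: temporarily route the Auto*
# loaders to model-specific ones, then restore them even if loading raises.
from contextlib import contextmanager

from transformers import AutoConfig, AutoModel


@contextmanager
def override_auto_classes(config_func, model_func, skip_taskmanager=True):
    original_config_fn = AutoConfig.from_pretrained
    original_model_fn = AutoModel.from_pretrained
    try:
        AutoConfig.from_pretrained = config_func
        AutoModel.from_pretrained = model_func
        yield
    finally:
        AutoConfig.from_pretrained = original_config_fn
        AutoModel.from_pretrained = original_model_fn
```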