optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,283 @@
+ import math
+ from typing import Optional, Tuple
+
+ import torch
+ from transformers import PretrainedConfig
+
+
+ def _compute_default_rope_parameters(
+     config: Optional[PretrainedConfig] = None,
+     seq_len: Optional[int] = None,
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies according to the original RoPE implementation
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+     """
+
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+     dim = int(head_dim * partial_rotary_factor)
+
+     attention_factor = 1.0  # Unused in this type of RoPE
+
+     # Compute the inverse frequencies
+     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+     return inv_freq, attention_factor
+
+
+ def _compute_linear_scaling_rope_parameters(
+     config: Optional[PretrainedConfig] = None,
+     seq_len: Optional[int] = None,
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+     """
+
+     factor = config.rope_scaling["factor"]
+
+     # Gets the default RoPE parameters
+     inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+     # Then applies linear scaling to the frequencies.
+     # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+     # applying scaling to the inverse frequencies is equivalent.
+     inv_freq /= factor
+     return inv_freq, attention_factor
+
+
+ def _compute_dynamic_ntk_parameters(
+     config: Optional[PretrainedConfig] = None,
+     seq_len: Optional[int] = None,
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length, used to update the dynamic RoPE at inference time.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+     """
+
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+     dim = int(head_dim * partial_rotary_factor)
+     max_position_embeddings = config.max_position_embeddings
+     factor = config.rope_scaling["factor"]
+
+     attention_factor = 1.0  # Unused in this type of RoPE
+
+     # Process with chunk_size to reduce precesion error
+     chunk_size = 4096
+     chunks = (seq_len + chunk_size - 1) // chunk_size
+
+     inv_freq_list = []
+     for i in range(chunks):
+         start = i * chunk_size
+         end = min((i + 1) * chunk_size, seq_len)
+
+         seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
+         seq_lens = torch.where(seq_lens > max_position_embeddings, seq_lens, max_position_embeddings)
+
+         # Compute the inverse frequencies for each chunk
+         scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+         inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+
+         inv_freq_list.append(inv_freq)
+
+     final_inv_freq = torch.cat(inv_freq_list, dim=0)
+
+     return final_inv_freq, attention_factor
+
+
+ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with NTK scaling. Please refer to the
+     [original paper](https://arxiv.org/abs/2309.00071)
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin.
+     """
+
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+     dim = int(head_dim * partial_rotary_factor)
+     max_position_embeddings = config.max_position_embeddings
+     factor = config.rope_scaling["factor"]
+
+     # Sets the attention factor as suggested in the paper
+     attention_factor = config.rope_scaling.get("attention_factor")
+     if attention_factor is None:
+         attention_factor = 0.1 * math.log(factor) + 1.0
+
+     # Optional config options
+     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+     beta_fast = config.rope_scaling.get("beta_fast") or 32
+     beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+     # Compute the inverse frequencies
+     def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+         """Inverse dimension formula to find the dimension based on the number of rotations"""
+         return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+     def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+         """Find dimension range bounds based on rotations"""
+         low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+         return max(low, 0), min(high, dim - 1)
+
+     def linear_ramp_factor(min, max, dim):
+         if min == max:
+             max += 0.001  # Prevent singularity
+
+         linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+         ramp_func = torch.clamp(linear_func, 0, 1)
+         return ramp_func
+
+     # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+     # to expand the possible context length. In other words, interpolation = apply scaling factor.
+     pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+     inv_freq_extrapolation = 1.0 / pos_freqs
+     inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+     low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+
+     # Get n-dimensional rotational scaling corrected for extrapolation
+     inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+     inv_freq = (
+         inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+         + inv_freq_extrapolation * inv_freq_extrapolation_factor
+     )
+
+     return inv_freq, attention_factor
+
+
+ def _compute_longrope_parameters(
+     config: PretrainedConfig, seq_len: Optional[int] = None
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+     [original implementation](https://github.com/microsoft/LongRoPE)
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin.
+     """
+
+     base = config.rope_theta
+     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+     dim = int(head_dim * partial_rotary_factor)
+     long_factor = config.rope_scaling["long_factor"]
+     short_factor = config.rope_scaling["short_factor"]
+     factor = config.rope_scaling.get("factor")
+     attention_factor = config.rope_scaling.get("attention_factor")
+
+     # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+     # values to compute the default attention scaling factor, instead of using `factor`.
+     if hasattr(config, "original_max_position_embeddings"):
+         max_position_embeddings = config.original_max_position_embeddings
+         expanded_max_position_embeddings = config.max_position_embeddings
+         factor = expanded_max_position_embeddings / max_position_embeddings
+     else:
+         max_position_embeddings = config.max_position_embeddings
+         expanded_max_position_embeddings = max_position_embeddings * factor
+
+     # Sets the attention factor as suggested in the paper
+     if attention_factor is None:
+         if factor <= 1.0:
+             attention_factor = 1.0
+         else:
+             attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+
+     # Compute the inverse frequencies -- scaled based on the target sequence length
+     if expanded_max_position_embeddings > max_position_embeddings:
+         ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+     else:
+         ext_factors = torch.tensor(short_factor, dtype=torch.float32)
+     inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+     inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+     return inv_freq, attention_factor
+
+
+ def _compute_llama3_parameters(
+     config: PretrainedConfig, seq_len: Optional[int] = None
+ ) -> Tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies for llama 3.1.
+
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin.
+     """
+     # Gets the default RoPE parameters
+     inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+     factor = config.rope_scaling["factor"]  # `8` in the original implementation
+     low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+     high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+     old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+     low_freq_wavelen = old_context_len / low_freq_factor
+     high_freq_wavelen = old_context_len / high_freq_factor
+
+     wavelen = 2 * math.pi / inv_freq
+     # wavelen < high_freq_wavelen: do nothing
+     # wavelen > low_freq_wavelen: divide by factor
+     inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+     # otherwise: interpolate between the two, using a smooth factor
+     smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+     smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+     is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+     inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+     return inv_freq_llama, attention_factor
+
+
+ # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+ # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+ # parameterizations, as long as the callable has the same signature.
+ ROPE_INIT_FUNCTIONS = {
+     "default": _compute_default_rope_parameters,
+     "linear": _compute_linear_scaling_rope_parameters,
+     "dynamic": _compute_dynamic_ntk_parameters,
+     "yarn": _compute_yarn_parameters,
+     "longrope": _compute_longrope_parameters,
+     "llama3": _compute_llama3_parameters,
+ }
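The new `optimum/rbln/transformers/modeling_rope_utils.py` mirrors the RoPE parameter initializers from recent `transformers` releases, keyed by the `rope_type` entry of a model's `rope_scaling` config. A minimal usage sketch follows; the config values and the direct import path are illustrative assumptions, not taken from this diff:

```python
# Hypothetical usage of the ROPE_INIT_FUNCTIONS table added above.
# Any config exposing the same attributes (rope_theta, hidden_size,
# num_attention_heads, rope_scaling, ...) would work the same way.
from transformers import PretrainedConfig

from optimum.rbln.transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = PretrainedConfig(
    hidden_size=4096,
    num_attention_heads=32,
    rope_theta=10000.0,
    max_position_embeddings=8192,
    rope_scaling={"rope_type": "linear", "factor": 2.0},
)

rope_type = config.rope_scaling.get("rope_type", "default")
inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, seq_len=8192)
print(inv_freq.shape)  # torch.Size([64]) -- one inverse frequency per pair of head dims
```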
@@ -21,32 +21,84 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

+ from typing import TYPE_CHECKING

- from .auto import (
-     RBLNAutoModel,
-     RBLNAutoModelForAudioClassification,
-     RBLNAutoModelForCausalLM,
-     RBLNAutoModelForCTC,
-     RBLNAutoModelForDepthEstimation,
-     RBLNAutoModelForImageClassification,
-     RBLNAutoModelForMaskedLM,
-     RBLNAutoModelForQuestionAnswering,
-     RBLNAutoModelForSeq2SeqLM,
-     RBLNAutoModelForSequenceClassification,
-     RBLNAutoModelForSpeechSeq2Seq,
-     RBLNAutoModelForVision2Seq,
- )
- from .bart import RBLNBartModel
- from .bert import RBLNBertModel
- from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
- from .dpt import RBLNDPTForDepthEstimation
- from .gemma import RBLNGemmaForCausalLM
- from .gpt2 import RBLNGPT2LMHeadModel
- from .llama import RBLNLlamaForCausalLM
- from .llava_next import RBLNLlavaNextForConditionalGeneration
- from .midm import RBLNMidmLMHeadModel
- from .mistral import RBLNMistralForCausalLM
- from .phi import RBLNPhiForCausalLM
- from .wav2vec2 import RBLNWav2Vec2ForCTC
- from .whisper import RBLNWhisperForConditionalGeneration
- from .xlm_roberta import RBLNXLMRobertaModel
+ from transformers.utils import _LazyModule
+
+
+ _import_structure = {
+     "auto": [
+         "RBLNAutoModel",
+         "RBLNAutoModelForAudioClassification",
+         "RBLNAutoModelForCausalLM",
+         "RBLNAutoModelForCTC",
+         "RBLNAutoModelForDepthEstimation",
+         "RBLNAutoModelForImageClassification",
+         "RBLNAutoModelForMaskedLM",
+         "RBLNAutoModelForQuestionAnswering",
+         "RBLNAutoModelForSeq2SeqLM",
+         "RBLNAutoModelForSequenceClassification",
+         "RBLNAutoModelForSpeechSeq2Seq",
+         "RBLNAutoModelForVision2Seq",
+     ],
+     "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
+     "bert": ["RBLNBertModel"],
+     "clip": ["RBLNCLIPTextModel", "RBLNCLIPTextModelWithProjection", "RBLNCLIPVisionModel"],
+     "dpt": ["RBLNDPTForDepthEstimation"],
+     "exaone": ["RBLNExaoneForCausalLM"],
+     "gemma": ["RBLNGemmaForCausalLM"],
+     "gpt2": ["RBLNGPT2LMHeadModel"],
+     "llama": ["RBLNLlamaForCausalLM"],
+     "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
+     "midm": ["RBLNMidmLMHeadModel"],
+     "mistral": ["RBLNMistralForCausalLM"],
+     "phi": ["RBLNPhiForCausalLM"],
+     "qwen2": ["RBLNQwen2ForCausalLM"],
+     "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
+     "wav2vec2": ["RBLNWav2Vec2ForCTC"],
+     "whisper": ["RBLNWhisperForConditionalGeneration"],
+     "xlm_roberta": ["RBLNXLMRobertaModel"],
+ }
+
+ if TYPE_CHECKING:
+     from .auto import (
+         RBLNAutoModel,
+         RBLNAutoModelForAudioClassification,
+         RBLNAutoModelForCausalLM,
+         RBLNAutoModelForCTC,
+         RBLNAutoModelForDepthEstimation,
+         RBLNAutoModelForImageClassification,
+         RBLNAutoModelForMaskedLM,
+         RBLNAutoModelForQuestionAnswering,
+         RBLNAutoModelForSeq2SeqLM,
+         RBLNAutoModelForSequenceClassification,
+         RBLNAutoModelForSpeechSeq2Seq,
+         RBLNAutoModelForVision2Seq,
+     )
+     from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
+     from .bert import RBLNBertModel
+     from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+     from .dpt import RBLNDPTForDepthEstimation
+     from .exaone import RBLNExaoneForCausalLM
+     from .gemma import RBLNGemmaForCausalLM
+     from .gpt2 import RBLNGPT2LMHeadModel
+     from .llama import RBLNLlamaForCausalLM
+     from .llava_next import RBLNLlavaNextForConditionalGeneration
+     from .midm import RBLNMidmLMHeadModel
+     from .mistral import RBLNMistralForCausalLM
+     from .phi import RBLNPhiForCausalLM
+     from .qwen2 import RBLNQwen2ForCausalLM
+     from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+     from .wav2vec2 import RBLNWav2Vec2ForCTC
+     from .whisper import RBLNWhisperForConditionalGeneration
+     from .xlm_roberta import RBLNXLMRobertaModel
+
+ else:
+     import sys
+
+     sys.modules[__name__] = _LazyModule(
+         __name__,
+         globals()["__file__"],
+         _import_structure,
+         module_spec=__spec__,
+     )
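With the `_LazyModule` indirection above, importing `optimum.rbln.transformers.models` no longer pulls in every backend eagerly; a submodule is only imported the first time one of its exported names is accessed. A hedged sketch of the effect (the exact module path of the returned class depends on the installed build):

```python
# Sketch: attribute access on the lazy package triggers the real submodule import.
import importlib

models = importlib.import_module("optimum.rbln.transformers.models")

# This line makes _LazyModule import `.llama` and cache the attribute.
RBLNLlamaForCausalLM = models.RBLNLlamaForCausalLM
print(RBLNLlamaForCausalLM.__module__)  # e.g. "optimum.rbln.transformers.models.llama.modeling_llama"
```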
@@ -42,6 +42,7 @@ from .auto_factory import _BaseAutoModelClass
  MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
      {
          "midm": "MidmLMHeadModel",
+         "exaone": "ExaoneForCausalLM",
      }
  )

@@ -22,4 +22,4 @@
  # from Rebellions Inc.

  from .bart_architecture import BartDecoderWrapper, BartEncoderWrapper
- from .modeling_bart import RBLNBartModel
+ from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
@@ -47,6 +47,12 @@ from transformers.utils import logging
  logger = logging.get_logger(__name__)


+ class BartWrapper:
+     def __init__(self, model):
+         self.encoder = BartEncoderWrapper(model)
+         self.decoder = BartDecoderWrapper(model)
+
+
  class _BartAttention(BartAttention):
      def forward(
          self,
@@ -238,6 +244,7 @@ class _BartSdpaAttention(BartSdpaAttention):
              value_states, dim=2, start=cache_position, end=cache_position + 1
          )

+         # need 4d shape (input tensors) for scaled_dot_product_attention
          attn_output = torch.nn.functional.scaled_dot_product_attention(
              query_states,
              key_states,
@@ -324,7 +331,6 @@ class _BartDecoder(BartDecoder):
          attn_impl: str = "eager",
      ):
          # embedding
-         # thkim fix : transformers == 4.44.2 compile
          if hasattr(self, "embed_scale"):
              inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
          else:
@@ -336,13 +342,15 @@
              hidden_states = inputs_embeds + positions
          else:
              hidden_all = []
+             # compiler pattern base dependency -> take + add
              for i in range(input_ids.shape[0]):
                  # cache position [N,1]
                  positions_idx = cache_position[i]
+                 # offset is set 2 in bart embedding
                  position_weight = self.embed_positions.weight[2:]
                  position = position_weight[positions_idx]
-                 tmp_hidden = position + inputs_embeds[i]
-                 hidden_all.append(tmp_hidden)
+                 batch_hidden = position + inputs_embeds[i]
+                 hidden_all.append(batch_hidden)
              hidden_states = torch.stack(hidden_all, dim=0)

          hidden_states = self.layernorm_embedding(hidden_states)
@@ -444,6 +452,7 @@ class BartDecoderWrapper(torch.nn.Module):
              self_kv_cache.append(past_key_values[i][1])
          self_kv_cache = torch.stack(self_kv_cache, dim=0)

+         # return batch_position to keep it as a variable within the graph
          return lm_logits, self_kv_cache, batch_position


@@ -467,9 +476,6 @@ class BartEncoderWrapper(torch.nn.Module):
          cross_key_value: torch.Tensor = None,
          batch_idx: torch.Tensor = None,
      ) -> Tuple[torch.Tensor]:
-         encoder_batch_size = input_ids.shape[0]
-         decoder_batch_size = encoder_batch_size  # TODO(taehoon) fix to enable beam-search
-
          # 1. run encoder
          encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
          last_hidden_states = encoder_outputs[0]
@@ -477,19 +483,19 @@
          # 2. run dummy decoder to get pre-calculated cross-key_values for generation
          dummy_past_key_value = []
          for _ in range(self.num_layers):
-             pkv_self_attn_key = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-             pkv_self_attn_value = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
-             pkv_cross_attn_key = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
-             pkv_cross_attn_value = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
+             pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+             pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
+             pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
+             pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
              layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
              dummy_past_key_value.append(layer_pkv)

-         decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.float32)
+         decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
          decoder_attention_mask[:, :1] = 1

          decoder_outputs = _BartDecoder.forward(
              self.decoder,
-             input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
+             input_ids=torch.zeros((1, 1), dtype=torch.int64),
              attention_mask=decoder_attention_mask,
              encoder_attention_mask=attention_mask,
              cache_position=torch.tensor(0, dtype=torch.int32),
@@ -22,23 +22,25 @@
  # from Rebellions Inc.

  import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

- from transformers import AutoModel, BartConfig, BartModel, PretrainedConfig
+ from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig

  from ....modeling_base import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
+ from ....utils.logging import get_logger
+ from ...models.seq2seq import RBLNModelForSeq2SeqLM
+ from .bart_architecture import BartWrapper


- logger = logging.getLogger(__name__)
+ logger = get_logger()
+

  if TYPE_CHECKING:
-     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel


  class RBLNBartModel(RBLNModel):
-     auto_model_class = AutoModel  # feature extraction
      original_model_class = BartModel
      original_config_class = BartConfig

@@ -104,3 +106,20 @@ class RBLNBartModel(RBLNModel):

          rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
          return rbln_config
+
+
+ class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     @classmethod
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+         return BartWrapper(model)
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(BartForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+
+         return val
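The `__getattr__` hook in `RBLNBartForConditionalGeneration` falls back to `BartForConditionalGeneration` for anything the RBLN class does not define, rebinding plain functions that expect `self` to the RBLN instance so Hugging Face helper methods keep working. A standalone sketch of the same delegation pattern (class and attribute names here are illustrative, not from optimum-rbln):

```python
# Sketch of the __getattr__ delegation used above: unknown attributes are looked
# up on a reference class, and functions that take `self` are rebound to the wrapper.
import inspect
from typing import Any, Callable


class Reference:
    vocab_size = 50265  # non-callable attributes pass through unchanged

    def greet(self, name: str) -> str:
        return f"hello {name}"


class Wrapper:
    def __getattr__(self, __name: str) -> Any:
        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(Reference, __name)
        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val


w = Wrapper()
print(w.greet("rbln"))  # "hello rbln" -- Reference.greet is called with w as self
print(w.vocab_size)     # 50265
```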
@@ -25,7 +25,7 @@ import inspect
  import logging
  from typing import TYPE_CHECKING, Any, Dict, Optional, Union

- from transformers import AutoModel, BertConfig, BertModel, PretrainedConfig
+ from transformers import BertConfig, BertModel, PretrainedConfig

  from ....modeling_base import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -38,7 +38,6 @@ if TYPE_CHECKING:


  class RBLNBertModel(RBLNModel):
-     auto_model_class = AutoModel  # feature extraction
      original_model_class = BertModel
      original_config_class = BertConfig

@@ -26,8 +26,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

  import torch
  from transformers import (
-     AutoConfig,
-     AutoModel,
      CLIPTextConfig,
      CLIPTextModel,
      CLIPTextModelWithProjection,
@@ -39,6 +37,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput

  from ....modeling_base import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
+ from ....utils.context import override_auto_classes


  logger = logging.getLogger(__name__)
@@ -58,19 +57,14 @@ class _TextEncoder(torch.nn.Module):


  class RBLNCLIPTextModel(RBLNModel):
-     auto_model_class = AutoModel  # feature extraction
-     original_model_class = CLIPTextModel
-     original_config_class = CLIPTextConfig
-
      @classmethod
      def from_pretrained(cls, *args, **kwargs):
-         configtmp = AutoConfig.from_pretrained
-         modeltmp = AutoModel.from_pretrained
-         AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-         AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-         rt = super().from_pretrained(*args, **kwargs)
-         AutoConfig.from_pretrained = configtmp
-         AutoModel.from_pretrained = modeltmp
+         with override_auto_classes(
+             config_func=CLIPTextConfig.from_pretrained,
+             model_func=CLIPTextModel.from_pretrained,
+             skip_taskmanager=False,
+         ):
+             rt = super().from_pretrained(*args, **kwargs)
          return rt

      @classmethod
@@ -134,18 +128,14 @@ class _VisionEncoder(torch.nn.Module):


  class RBLNCLIPVisionModel(RBLNModel):
-     original_model_class = CLIPVisionModel
-     original_config_class = CLIPVisionConfig
-
      @classmethod
      def from_pretrained(cls, *args, **kwargs):
-         configtmp = AutoConfig.from_pretrained
-         modeltmp = AutoModel.from_pretrained
-         AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-         AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-         rt = super().from_pretrained(*args, **kwargs)
-         AutoConfig.from_pretrained = configtmp
-         AutoModel.from_pretrained = modeltmp
+         with override_auto_classes(
+             config_func=CLIPVisionConfig.from_pretrained,
+             model_func=CLIPVisionModel.from_pretrained,
+             skip_taskmanager=False,
+         ):
+             rt = super().from_pretrained(*args, **kwargs)
          return rt

      @classmethod
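Both CLIP classes now delegate the temporary `AutoConfig`/`AutoModel` patching to the new `override_auto_classes` context manager from `optimum/rbln/utils/context.py` (+58 lines, not shown in this excerpt). Its body is not part of this diff; the sketch below is an assumption inferred from the monkey-patch code it replaces, and the real implementation (including what `skip_taskmanager` controls) may differ:

```python
# Hedged sketch of what override_auto_classes likely does, inferred from the
# removed monkey-patch code above; the real implementation may differ.
from contextlib import contextmanager

from transformers import AutoConfig, AutoModel


@contextmanager
def override_auto_classes(config_func, model_func, skip_taskmanager=False):
    # skip_taskmanager is accepted to mirror the call sites; its behavior is not modeled here.
    original_config = AutoConfig.from_pretrained
    original_model = AutoModel.from_pretrained
    try:
        # Route Auto* loading to the concrete classes while the block runs.
        AutoConfig.from_pretrained = config_func
        AutoModel.from_pretrained = model_func
        yield
    finally:
        # Always restore the original hooks, even if loading raises.
        AutoConfig.from_pretrained = original_config
        AutoModel.from_pretrained = original_model
```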
@@ -26,8 +26,6 @@ from .decoderonly_architecture import (
      DecoderOnlyDecoderLayer,
      DecoderOnlyModel,
      DecoderOnlyWrapper,
-     DynamicNTKScalingRotaryEmbedding,
-     LinearScalingRotaryEmbedding,
      RotaryEmbedding,
      apply_rotary_pos_emb,
      rotate_half,