optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (52)
  1. optimum/rbln/__init__.py +5 -1
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
  4. optimum/rbln/diffusers/models/controlnet.py +36 -56
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
  16. optimum/rbln/modeling_base.py +12 -5
  17. optimum/rbln/modeling_diffusers.py +400 -0
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/cache_utils.py +5 -9
  20. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  21. optimum/rbln/transformers/models/__init__.py +80 -31
  22. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
  23. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  25. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
  26. optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
  27. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  29. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  30. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  31. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  32. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
  33. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
  35. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  36. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  37. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  38. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  39. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  40. optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  42. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  43. optimum/rbln/utils/context.py +58 -0
  44. optimum/rbln/utils/decorator_utils.py +55 -0
  45. optimum/rbln/utils/import_utils.py +7 -0
  46. optimum/rbln/utils/runtime_utils.py +4 -4
  47. optimum/rbln/utils/timer_utils.py +2 -2
  48. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
  49. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
  50. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/modeling_rope_utils.py (new file)
@@ -0,0 +1,283 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from transformers import PretrainedConfig
+
+
+def _compute_default_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies according to the original RoPE implementation
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+    return inv_freq, attention_factor
+
+
+def _compute_linear_scaling_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    factor = config.rope_scaling["factor"]
+
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    # Then applies linear scaling to the frequencies.
+    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+    # applying scaling to the inverse frequencies is equivalent.
+    inv_freq /= factor
+    return inv_freq, attention_factor
+
+
+def _compute_dynamic_ntk_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length, used to update the dynamic RoPE at inference time.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Process with chunk_size to reduce precision error
+    chunk_size = 4096
+    chunks = (seq_len + chunk_size - 1) // chunk_size
+
+    inv_freq_list = []
+    for i in range(chunks):
+        start = i * chunk_size
+        end = min((i + 1) * chunk_size, seq_len)
+
+        seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
+        seq_lens = torch.where(seq_lens > max_position_embeddings, seq_lens, max_position_embeddings)
+
+        # Compute the inverse frequencies for each chunk
+        scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+        inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+
+        inv_freq_list.append(inv_freq)
+
+    final_inv_freq = torch.cat(inv_freq_list, dim=0)
+
+    return final_inv_freq, attention_factor
+
+
+def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Please refer to the
+    [original paper](https://arxiv.org/abs/2309.00071)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    # Sets the attention factor as suggested in the paper
+    attention_factor = config.rope_scaling.get("attention_factor")
+    if attention_factor is None:
+        attention_factor = 0.1 * math.log(factor) + 1.0
+
+    # Optional config options
+    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+    beta_fast = config.rope_scaling.get("beta_fast") or 32
+    beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+    # Compute the inverse frequencies
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+        """Find dimension range bounds based on rotations"""
+        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
+    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+    inv_freq_extrapolation = 1.0 / pos_freqs
+    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+
+    # Get n-dimensional rotational scaling corrected for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
+
+    return inv_freq, attention_factor
+
+
+def _compute_longrope_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+    [original implementation](https://github.com/microsoft/LongRoPE)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    long_factor = config.rope_scaling["long_factor"]
+    short_factor = config.rope_scaling["short_factor"]
+    factor = config.rope_scaling.get("factor")
+    attention_factor = config.rope_scaling.get("attention_factor")
+
+    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if hasattr(config, "original_max_position_embeddings"):
+        max_position_embeddings = config.original_max_position_embeddings
+        expanded_max_position_embeddings = config.max_position_embeddings
+        factor = expanded_max_position_embeddings / max_position_embeddings
+    else:
+        max_position_embeddings = config.max_position_embeddings
+        expanded_max_position_embeddings = max_position_embeddings * factor
+
+    # Sets the attention factor as suggested in the paper
+    if attention_factor is None:
+        if factor <= 1.0:
+            attention_factor = 1.0
+        else:
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+
+    # Compute the inverse frequencies -- scaled based on the target sequence length
+    if expanded_max_position_embeddings > max_position_embeddings:
+        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+    else:
+        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
+    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+    return inv_freq, attention_factor
+
+
+def _compute_llama3_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor
+
+
+# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+# parameterizations, as long as the callable has the same signature.
+ROPE_INIT_FUNCTIONS = {
+    "default": _compute_default_rope_parameters,
+    "linear": _compute_linear_scaling_rope_parameters,
+    "dynamic": _compute_dynamic_ntk_parameters,
+    "yarn": _compute_yarn_parameters,
+    "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
+}
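
Every entry in `ROPE_INIT_FUNCTIONS` shares the same signature: it takes a `PretrainedConfig` (plus an optional `seq_len`) and returns the inverse-frequency tensor together with a post-processing attention scaling factor. A minimal usage sketch, assuming the module path from the file list above and an illustrative locally constructed Llama-style config (the model settings below are not taken from this release):

```python
from transformers import LlamaConfig

from optimum.rbln.transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# Illustrative config: 32 heads of dim 128, default RoPE (no rope_scaling entry).
config = LlamaConfig(hidden_size=4096, num_attention_heads=32, rope_theta=10000.0)

# Older configs store the type under "type", newer ones under "rope_type".
rope_scaling = getattr(config, "rope_scaling", None) or {}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))

rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
inv_freq, attention_factor = rope_init_fn(config, seq_len=config.max_position_embeddings)

assert inv_freq.shape == (64,)  # head_dim // 2 inverse frequencies
assert attention_factor == 1.0  # unused for the "default" rope_type
```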
optimum/rbln/transformers/models/__init__.py
@@ -21,35 +21,84 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+from typing import TYPE_CHECKING
 
-from .auto import (
-    RBLNAutoModel,
-    RBLNAutoModelForAudioClassification,
-    RBLNAutoModelForCausalLM,
-    RBLNAutoModelForCTC,
-    RBLNAutoModelForDepthEstimation,
-    RBLNAutoModelForImageClassification,
-    RBLNAutoModelForMaskedLM,
-    RBLNAutoModelForQuestionAnswering,
-    RBLNAutoModelForSeq2SeqLM,
-    RBLNAutoModelForSequenceClassification,
-    RBLNAutoModelForSpeechSeq2Seq,
-    RBLNAutoModelForVision2Seq,
-)
-from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
-from .bert import RBLNBertModel
-from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
-from .dpt import RBLNDPTForDepthEstimation
-from .exaone import RBLNExaoneForCausalLM
-from .gemma import RBLNGemmaForCausalLM
-from .gpt2 import RBLNGPT2LMHeadModel
-from .llama import RBLNLlamaForCausalLM
-from .llava_next import RBLNLlavaNextForConditionalGeneration
-from .midm import RBLNMidmLMHeadModel
-from .mistral import RBLNMistralForCausalLM
-from .phi import RBLNPhiForCausalLM
-from .qwen2 import RBLNQwen2ForCausalLM
-from .t5 import RBLNT5ForConditionalGeneration
-from .wav2vec2 import RBLNWav2Vec2ForCTC
-from .whisper import RBLNWhisperForConditionalGeneration
-from .xlm_roberta import RBLNXLMRobertaModel
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "auto": [
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+    ],
+    "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
+    "bert": ["RBLNBertModel"],
+    "clip": ["RBLNCLIPTextModel", "RBLNCLIPTextModelWithProjection", "RBLNCLIPVisionModel"],
+    "dpt": ["RBLNDPTForDepthEstimation"],
+    "exaone": ["RBLNExaoneForCausalLM"],
+    "gemma": ["RBLNGemmaForCausalLM"],
+    "gpt2": ["RBLNGPT2LMHeadModel"],
+    "llama": ["RBLNLlamaForCausalLM"],
+    "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
+    "midm": ["RBLNMidmLMHeadModel"],
+    "mistral": ["RBLNMistralForCausalLM"],
+    "phi": ["RBLNPhiForCausalLM"],
+    "qwen2": ["RBLNQwen2ForCausalLM"],
+    "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
+    "wav2vec2": ["RBLNWav2Vec2ForCTC"],
+    "whisper": ["RBLNWhisperForConditionalGeneration"],
+    "xlm_roberta": ["RBLNXLMRobertaModel"],
+}
+
+if TYPE_CHECKING:
+    from .auto import (
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+    )
+    from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
+    from .bert import RBLNBertModel
+    from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+    from .dpt import RBLNDPTForDepthEstimation
+    from .exaone import RBLNExaoneForCausalLM
+    from .gemma import RBLNGemmaForCausalLM
+    from .gpt2 import RBLNGPT2LMHeadModel
+    from .llama import RBLNLlamaForCausalLM
+    from .llava_next import RBLNLlavaNextForConditionalGeneration
+    from .midm import RBLNMidmLMHeadModel
+    from .mistral import RBLNMistralForCausalLM
+    from .phi import RBLNPhiForCausalLM
+    from .qwen2 import RBLNQwen2ForCausalLM
+    from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+    from .wav2vec2 import RBLNWav2Vec2ForCTC
+    from .whisper import RBLNWhisperForConditionalGeneration
+    from .xlm_roberta import RBLNXLMRobertaModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
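
With `_LazyModule`, importing `optimum.rbln.transformers.models` no longer eagerly imports every model file: each submodule listed in `_import_structure` is loaded the first time one of its classes is accessed, while the `TYPE_CHECKING` branch keeps static type checkers and IDEs aware of the re-exports. A usage sketch, assuming the 0.1.13 wheel is installed:

```python
# The public surface is unchanged: both forms below resolve through the lazy module.
import optimum.rbln.transformers.models as rbln_models

# Attribute access triggers the import of the owning submodule (here, ".llama").
llama_cls = rbln_models.RBLNLlamaForCausalLM

# from-imports behave the same way and only load the submodule they name.
from optimum.rbln.transformers.models import RBLNT5EncoderModel  # newly exported here

print(llama_cls.__name__, RBLNT5EncoderModel.__name__)
```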
optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -26,8 +26,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 from transformers import (
-    AutoConfig,
-    AutoModel,
     CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
@@ -39,6 +37,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
 from ....modeling_base import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....utils.context import override_auto_classes
 
 
 logger = logging.getLogger(__name__)
@@ -58,18 +57,14 @@ class _TextEncoder(torch.nn.Module):
 
 
 class RBLNCLIPTextModel(RBLNModel):
-    original_model_class = CLIPTextModel
-    original_config_class = CLIPTextConfig
-
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
+        with override_auto_classes(
+            config_func=CLIPTextConfig.from_pretrained,
+            model_func=CLIPTextModel.from_pretrained,
+            skip_taskmanager=False,
+        ):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt
 
     @classmethod
@@ -133,18 +128,14 @@ class _VisionEncoder(torch.nn.Module):
 
 
 class RBLNCLIPVisionModel(RBLNModel):
-    original_model_class = CLIPVisionModel
-    original_config_class = CLIPVisionConfig
-
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
+        with override_auto_classes(
+            config_func=CLIPVisionConfig.from_pretrained,
+            model_func=CLIPVisionModel.from_pretrained,
+            skip_taskmanager=False,
+        ):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt
 
     @classmethod
optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -26,8 +26,6 @@ from .decoderonly_architecture import (
     DecoderOnlyDecoderLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
-    DynamicNTKScalingRotaryEmbedding,
-    LinearScalingRotaryEmbedding,
     RotaryEmbedding,
     apply_rotary_pos_emb,
     rotate_half,