optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_diffusers.py (new file)
@@ -0,0 +1,329 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+import copy
+import importlib
+from os import PathLike
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+import torch
+
+from .modeling import RBLNModel
+from .modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
+from .utils.decorator_utils import remove_compile_time_kwargs
+
+
+if TYPE_CHECKING:
+    from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+
+
+class RBLNDiffusionMixin:
+    """
+    RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
+    This mixin class serves as a base for implementing RBLN-compatible Stable Diffusion pipelines. It contains shared logic for
+    handling the core components of Stable Diffusion.
+
+    To use this mixin:
+
+    1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
+    2. Define the required _submodules class variable listing the components to be compiled.
+    3. If needed, implement get_default_rbln_config for custom configuration of submodules.
+
+    Example:
+        ```python
+        class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
+            _submodules = ["text_encoder", "unet", "vae"]
+
+            @classmethod
+            def get_default_rbln_config(cls, model, submodule_name, rbln_config):
+                # Configuration for other submodules...
+                pass
+        ```
+
+    Class Variables:
+        _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
+
+    Methods:
+        from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
+
+    Notes:
+        - When `export=True`, all compatible submodules will be compiled for NPU inference
+        - The compilation config can be customized per submodule by including submodule names
+          as keys in rbln_config
+    """
+
+    _submodules = []
+
+    @classmethod
+    @property
+    def img2img_pipeline(cls):
+        return "Img2Img" in cls.__name__
+
+    @classmethod
+    @property
+    def inpaint_pipeline(cls):
+        return "Inpaint" in cls.__name__
+
+    @classmethod
+    def get_submodule_rbln_config(
+        cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        submodule = getattr(model, submodule_name)
+        submodule_class_name = submodule.__class__.__name__
+
+        if submodule_class_name == "MultiControlNetModel":
+            submodule_class_name = "ControlNetModel"
+
+        submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
+
+        submodule_config = rbln_config.get(submodule_name, {})
+        submodule_config = copy.deepcopy(submodule_config)
+
+        pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
+
+        submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
+        submodule_config.update(
+            {
+                "img2img_pipeline": cls.img2img_pipeline,
+                "inpaint_pipeline": cls.inpaint_pipeline,
+            }
+        )
+        submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
+        return submodule_config
+
+    @staticmethod
+    def _maybe_apply_and_fuse_lora(
+        model: torch.nn.Module,
+        lora_ids: Optional[Union[str, List[str]]] = None,
+        lora_weights_names: Optional[Union[str, List[str]]] = None,
+        lora_scales: Optional[Union[float, List[float]]] = None,
+    ) -> torch.nn.Module:
+        lora_ids = [lora_ids] if isinstance(lora_ids, str) else lora_ids
+        lora_weights_names = [lora_weights_names] if isinstance(lora_weights_names, str) else lora_weights_names
+        lora_scales = [lora_scales] if isinstance(lora_scales, float) else lora_scales
+
+        # apply lora weights to the pipeline before compilation
+        if lora_ids and lora_weights_names:
+            if len(lora_ids) == 1:
+                if len(lora_ids) != len(lora_weights_names):
+                    raise ValueError(
+                        f"You must define the same number of lora ids ({len(lora_ids)}) and lora weight names ({len(lora_weights_names)})."
+                    )
+                else:
+                    model.load_lora_weights(lora_ids[0], weight_name=lora_weights_names[0])
+                    model.fuse_lora(lora_scale=lora_scales[0] if lora_scales else 1.0)
+            elif len(lora_ids) > 1:
+                if not len(lora_ids) == len(lora_weights_names):
+                    raise ValueError(
+                        f"If you fuse {len(lora_ids)} lora models, you must provide the same number of lora weight names."
+                    )
+
+                adapter_names = [f"adapter_{i}" for i in range(len(lora_ids))]
+
+                for lora_id, lora_weight, adapter_name in zip(lora_ids, lora_weights_names, adapter_names):
+                    model.load_lora_weights(lora_id, weight_name=lora_weight, adapter_name=adapter_name)
+
+                if lora_scales:
+                    model.set_adapters(adapter_names, adapter_weights=lora_scales)
+
+                model.fuse_lora()
+        return model
+
+    @classmethod
+    @use_rbln_config
+    def from_pretrained(
+        cls,
+        model_id: str,
+        *,
+        export: bool = False,
+        model_save_dir: Optional[PathLike] = None,
+        rbln_config: Dict[str, Any] = {},
+        lora_ids: Optional[Union[str, List[str]]] = None,
+        lora_weights_names: Optional[Union[str, List[str]]] = None,
+        lora_scales: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ) -> RBLNModel:
+        if export:
+            # keep submodules if user passed any of them.
+            passed_submodules = {
+                name: kwargs.pop(name) for name in cls._submodules if isinstance(kwargs.get(name), RBLNModel)
+            }
+
+        else:
+            # raise an error if any of the submodules is a torch module.
+            model_index_config = None
+            for submodule_name in cls._submodules:
+                if isinstance(kwargs.get(submodule_name), torch.nn.Module):
+                    raise AssertionError(
+                        f"{submodule_name} is not a compiled torch module. If you want to compile it, set `export=True`."
+                    )
+
+                # Load the submodule separately if runtime kwargs (e.g. device) are specified.
+                if submodule_config := rbln_config.get(submodule_name):
+                    if any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
+                        if model_index_config is None:
+                            model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
+
+                        module_name, class_name = model_index_config[submodule_name]
+                        if module_name != "optimum.rbln":
+                            raise ValueError(
+                                f"Invalid module_name '{module_name}' found in model_index.json for "
+                                f"submodule '{submodule_name}'. "
+                                "Expected 'optimum.rbln'. Please check the model_index.json configuration."
+                            )
+                        submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
+                        submodule = submodule_cls.from_pretrained(
+                            model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
+                        )
+                        kwargs[submodule_name] = submodule
+
+        with ContextRblnConfig(
+            device=rbln_config.get("device"),
+            device_map=rbln_config.get("device_map"),
+            create_runtimes=rbln_config.get("create_runtimes"),
+            optimize_host_mem=rbln_config.get("optimize_host_memory"),
+        ):
+            model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
+
+        if not export:
+            return model
+
+        model = cls._maybe_apply_and_fuse_lora(
+            model,
+            lora_ids=lora_ids,
+            lora_weights_names=lora_weights_names,
+            lora_scales=lora_scales,
+        )
+
+        compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
+        return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)
+
+    @classmethod
+    def _compile_submodules(
+        cls,
+        model: torch.nn.Module,
+        passed_submodules: Dict[str, RBLNModel],
+        model_save_dir: Optional[PathLike],
+        rbln_config: Dict[str, Any],
+    ) -> Dict[str, RBLNModel]:
+        compiled_submodules = {}
+
+        for submodule_name in cls._submodules:
+            submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
+            submodule_rbln_config = cls.get_submodule_rbln_config(model, submodule_name, rbln_config)
+
+            if submodule is None:
+                raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
+            elif isinstance(submodule, RBLNModel):
+                pass
+            elif submodule_name == "controlnet" and hasattr(submodule, "nets"):
+                # In case of multicontrolnet
+                submodule = cls._compile_multicontrolnet(
+                    controlnets=submodule,
+                    model_save_dir=model_save_dir,
+                    controlnet_rbln_config=submodule_rbln_config,
+                )
+            elif isinstance(submodule, torch.nn.Module):
+                submodule_cls: RBLNModel = getattr(
+                    importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
+                )
+                submodule = submodule_cls.from_model(
+                    model=submodule,
+                    subfolder=submodule_name,
+                    model_save_dir=model_save_dir,
+                    rbln_config=submodule_rbln_config,
+                )
+            else:
+                raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
+
+            compiled_submodules[submodule_name] = submodule
+        return compiled_submodules
+
+    @classmethod
+    def _compile_multicontrolnet(
+        cls,
+        controlnets: "MultiControlNetModel",
+        model_save_dir: Optional[PathLike],
+        controlnet_rbln_config: Dict[str, Any],
+    ):
+        # Compile multiple ControlNet models for a MultiControlNet setup
+        from .diffusers.models.controlnet import RBLNControlNetModel
+        from .diffusers.pipelines.controlnet import RBLNMultiControlNetModel
+
+        compiled_controlnets = [
+            RBLNControlNetModel.from_model(
+                model=controlnet,
+                subfolder="controlnet" if i == 0 else f"controlnet_{i}",
+                model_save_dir=model_save_dir,
+                rbln_config=controlnet_rbln_config,
+            )
+            for i, controlnet in enumerate(controlnets.nets)
+        ]
+        return RBLNMultiControlNetModel(compiled_controlnets, config=controlnets.nets[0].config)
+
+    @classmethod
+    def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
+        # Finalize the pipeline setup with compiled submodules and configurations
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            for submodule_name in cls._submodules:
+                delattr(model, submodule_name)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+            # FIXME: Here, model touches its submodules such as model.unet,
+            # causing warning messages.
+
+        update_dict = {}
+        for submodule_name in cls._submodules:
+            # replace submodule
+            setattr(model, submodule_name, submodules[submodule_name])
+            update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
+
+        # Update config to be able to load from model directory.
+        #
+        # e.g.)
+        # update_dict = {
+        #     "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
+        #     "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
+        #     "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
+        # }
+        model.register_to_config(**update_dict)
+
+        if model_save_dir:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        if rbln_config.get("optimize_host_memory") is False:
+            # Keep compiled_model objects for further analysis. -> TODO: remove soon...
+            model.compiled_models = []
+            for name in cls._submodules:
+                submodule = getattr(model, name)
+                model.compiled_models.extend(submodule.compiled_models)
+
+        return model
+
+    @remove_compile_time_kwargs
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
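
For context, a minimal usage sketch of a pipeline built on this mixin, assuming the `RBLNStableDiffusionPipeline` class exported by this package; the checkpoint id, save directory, and the per-submodule `batch_size` key are illustrative assumptions, not values taken from this diff:

```python
# Hedged sketch: compile a Stable Diffusion pipeline for the RBLN NPU, then reload it.
from optimum.rbln import RBLNStableDiffusionPipeline

# export=True compiles every entry in `_submodules` and registers the RBLN classes in model_index.json.
pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",         # hypothetical compatible checkpoint
    export=True,
    model_save_dir="./sd15-rbln",
    rbln_config={"unet": {"batch_size": 2}},  # per-submodule override, keyed by submodule name
)

# Reloading the saved directory (export defaults to False) reconstructs the pipeline
# from the compiled artifacts instead of recompiling.
pipe = RBLNStableDiffusionPipeline.from_pretrained("./sd15-rbln")
image = pipe("a watercolor painting of a lighthouse").images[0]
```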
optimum/rbln/transformers/__init__.py
@@ -28,7 +28,6 @@ from transformers.utils import _LazyModule
 
 _import_structure = {
     "cache_utils": ["RebelDynamicCache"],
-    "generation": ["BatchTextIteratorStreamer"],
     "models": [
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
@@ -57,6 +56,7 @@ _import_structure = {
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNPhiForCausalLM",
+        "RBLNT5EncoderModel",
         "RBLNT5ForConditionalGeneration",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
@@ -67,7 +67,6 @@ _import_structure = {
 
 if TYPE_CHECKING:
     from .cache_utils import RebelDynamicCache
-    from .generation import BatchTextIteratorStreamer
     from .models import (
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
@@ -97,6 +96,7 @@ if TYPE_CHECKING:
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
         RBLNQwen2ForCausalLM,
+        RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
optimum/rbln/transformers/cache_utils.py
@@ -12,9 +12,11 @@ class RebelDynamicCache(DynamicCache):
     `[batch_size, num_heads, seq_len, head_dim]`.
     """
 
-    def __init__(self, current_steps) -> None:
+    def __init__(self, position_ids) -> None:
         super().__init__()
-        self.current_steps = current_steps
+        # batch, _ = position_ids.shape
+        # current_steps = [position_ids[b][0] for b in range(batch)]
+        self.current_steps = position_ids[:, 0]
 
     def assign(
         self,
@@ -58,13 +60,7 @@ class RebelDynamicCache(DynamicCache):
     @classmethod
     def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "DynamicCache":
         """Converts a cache in the rbln cache format (list of past_kv) into an equivalent `DynamicCache`."""
-
-        batch, _ = position_ids.shape
-        current_steps = [position_ids[b][0] for b in range(batch)]
-
-        assert len(current_steps) == batch
-        cache = cls(current_steps)
-
+        cache = cls(position_ids)
         for layer_idx in range(num_hidden_layer):
            key_states = past_key_values[layer_idx * 2]
            value_states = past_key_values[layer_idx * 2 + 1]
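
The `__init__` change above replaces a per-batch Python loop with a single tensor slice; a small sketch of the equivalence (shapes assumed, plain PyTorch):

```python
import torch

# position_ids has shape [batch, seq_len]; the current decoding step of each batch row
# is simply its first position id.
position_ids = torch.tensor([[5, 6, 7], [9, 10, 11]])

current_steps_new = position_ids[:, 0]  # vectorized form used by the new __init__
current_steps_old = [position_ids[b][0] for b in range(position_ids.shape[0])]  # removed loop

assert current_steps_new.tolist() == [t.item() for t in current_steps_old]
```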
optimum/rbln/transformers/modeling_rope_utils.py (new file)
@@ -0,0 +1,283 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from transformers import PretrainedConfig
+
+
+def _compute_default_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies according to the original RoPE implementation
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+    return inv_freq, attention_factor
+
+
+def _compute_linear_scaling_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    factor = config.rope_scaling["factor"]
+
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    # Then applies linear scaling to the frequencies.
+    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+    # applying scaling to the inverse frequencies is equivalent.
+    inv_freq /= factor
+    return inv_freq, attention_factor
+
+
+def _compute_dynamic_ntk_parameters(
+    config: Optional[PretrainedConfig] = None,
+    seq_len: Optional[int] = None,
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length, used to update the dynamic RoPE at inference time.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Process in chunks of chunk_size to reduce precision error
+    chunk_size = 4096
+    chunks = (seq_len + chunk_size - 1) // chunk_size
+
+    inv_freq_list = []
+    for i in range(chunks):
+        start = i * chunk_size
+        end = min((i + 1) * chunk_size, seq_len)
+
+        seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
+        seq_lens = torch.where(seq_lens > max_position_embeddings, seq_lens, max_position_embeddings)
+
+        # Compute the inverse frequencies for each chunk
+        scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+        inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+
+        inv_freq_list.append(inv_freq)
+
+    final_inv_freq = torch.cat(inv_freq_list, dim=0)
+
+    return final_inv_freq, attention_factor
+
+
+def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Please refer to the
+    [original paper](https://arxiv.org/abs/2309.00071)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    # Sets the attention factor as suggested in the paper
+    attention_factor = config.rope_scaling.get("attention_factor")
+    if attention_factor is None:
+        attention_factor = 0.1 * math.log(factor) + 1.0
+
+    # Optional config options
+    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+    beta_fast = config.rope_scaling.get("beta_fast") or 32
+    beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+    # Compute the inverse frequencies
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+        """Find dimension range bounds based on rotations"""
+        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
+        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
+    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
+    inv_freq_extrapolation = 1.0 / pos_freqs
+    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+
+    # Get n-dimensional rotational scaling corrected for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
+
+    return inv_freq, attention_factor
+
+
+def _compute_longrope_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+    [original implementation](https://github.com/microsoft/LongRoPE)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    long_factor = config.rope_scaling["long_factor"]
+    short_factor = config.rope_scaling["short_factor"]
+    factor = config.rope_scaling.get("factor")
+    attention_factor = config.rope_scaling.get("attention_factor")
+
+    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if hasattr(config, "original_max_position_embeddings"):
+        max_position_embeddings = config.original_max_position_embeddings
+        expanded_max_position_embeddings = config.max_position_embeddings
+        factor = expanded_max_position_embeddings / max_position_embeddings
+    else:
+        max_position_embeddings = config.max_position_embeddings
+        expanded_max_position_embeddings = max_position_embeddings * factor
+
+    # Sets the attention factor as suggested in the paper
+    if attention_factor is None:
+        if factor <= 1.0:
+            attention_factor = 1.0
+        else:
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+
+    # Compute the inverse frequencies -- scaled based on the target sequence length
+    if expanded_max_position_embeddings > max_position_embeddings:
+        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+    else:
+        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
+    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+    return inv_freq, attention_factor
+
+
+def _compute_llama3_parameters(
+    config: PretrainedConfig, seq_len: Optional[int] = None
+) -> Tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor
+
+
+# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+# parameterizations, as long as the callable has the same signature.
+ROPE_INIT_FUNCTIONS = {
+    "default": _compute_default_rope_parameters,
+    "linear": _compute_linear_scaling_rope_parameters,
+    "dynamic": _compute_dynamic_ntk_parameters,
+    "yarn": _compute_yarn_parameters,
+    "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
+}
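
The dispatch site for this table is not part of this diff, but the intended pattern mirrors the upstream `transformers` RoPE utilities; a hedged sketch follows, where the `LlamaConfig` values and `seq_len` are illustrative assumptions:

```python
import torch
from transformers import LlamaConfig

from optimum.rbln.transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# "linear" scaling with factor 2.0, purely for illustration.
config = LlamaConfig(rope_scaling={"type": "linear", "factor": 2.0})
rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))

# Look up the parameter function by rope type and compute the inverse frequencies.
inv_freq, attention_factor = ROPE_INIT_FUNCTIONS[rope_type](config, seq_len=4096)

# The usual cos/sin tables follow from an outer product with the position indices.
positions = torch.arange(4096, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)
cos, sin = freqs.cos() * attention_factor, freqs.sin() * attention_factor
```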