optimum-rbln 0.1.12-py3-none-any.whl → 0.1.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. optimum/rbln/__init__.py +5 -1
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
  4. optimum/rbln/diffusers/models/controlnet.py +36 -56
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
  16. optimum/rbln/modeling_base.py +12 -5
  17. optimum/rbln/modeling_diffusers.py +400 -0
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/cache_utils.py +5 -9
  20. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  21. optimum/rbln/transformers/models/__init__.py +80 -31
  22. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
  23. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  25. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
  26. optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
  27. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  29. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  30. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  31. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  32. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
  33. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
  35. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  36. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  37. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  38. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  39. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  40. optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  42. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  43. optimum/rbln/utils/context.py +58 -0
  44. optimum/rbln/utils/decorator_utils.py +55 -0
  45. optimum/rbln/utils/import_utils.py +7 -0
  46. optimum/rbln/utils/runtime_utils.py +4 -4
  47. optimum/rbln/utils/timer_utils.py +2 -2
  48. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
  49. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
  50. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_diffusers.py (new file)
@@ -0,0 +1,400 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+import importlib
+from os import PathLike
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from .modeling_base import RBLNModel
+from .modeling_config import ContextRblnConfig, use_rbln_config
+from .utils.decorator_utils import remove_compile_time_kwargs
+
+
+if TYPE_CHECKING:
+    from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+
+
+class RBLNDiffusionMixin:
+    """
+    RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
+    This mixin class serves as a base for implementing RBLN-compatible Stable Diffusion pipelines. It contains shared logic for
+    handling the core components of Stable Diffusion.
+
+    To use this mixin:
+
+    1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
+    2. Define the required _submodules class variable listing the components to be compiled.
+    3. If needed, implement get_default_rbln_config for custom configuration of submodules.
+
+    Example:
+        ```python
+        class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
+            _submodules = ["text_encoder", "unet", "vae"]
+
+            @classmethod
+            def get_default_rbln_config(cls, model, submodule_name, rbln_config):
+                # Configuration for other submodules...
+                pass
+        ```
+
+    Class Variables:
+        _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])
+
+    Methods:
+        from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint
+
+    Notes:
+        - When `export=True`, all compatible submodules will be compiled for NPU inference
+        - The compilation config can be customized per submodule by including submodule names
+          as keys in rbln_config
+    """
+
+    _submodules = []
+
+    @classmethod
+    @property
+    def use_encode(cls):
+        return "Img2Img" in cls.__name__
+
+    @classmethod
+    def _get_unet_batch_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> int:
+        # Calculates the batch size based on guidance scale
+        batch_size = rbln_config.get("batch_size", 1)
+        do_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
+        return batch_size * 2 if do_guidance else batch_size
+
+    @classmethod
+    def _get_vae_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
+        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        if (image_size[0] is None) != (image_size[1] is None):
+            raise ValueError("Both image height and image width must be given or not given")
+        elif image_size[0] is None and image_size[1] is None:
+            if cls.use_encode:
+                sample_size = model.vae.config.sample_size
+            else:
+                # In case of text2img, sample size of vae decoder is determined by unet.
+                unet_sample_size = model.unet.config.sample_size
+                if isinstance(unet_sample_size, int):
+                    sample_size = unet_sample_size * model.vae_scale_factor
+                else:
+                    sample_size = (
+                        unet_sample_size[0] * model.vae_scale_factor,
+                        unet_sample_size[1] * model.vae_scale_factor,
+                    )
+
+        else:
+            sample_size = (image_size[0], image_size[1])
+        return sample_size
+
+    @classmethod
+    def _get_unet_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
+        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+        if (image_size[0] is None) != (image_size[1] is None):
+            raise ValueError("Both image height and image width must be given or not given")
+        elif image_size[0] is None and image_size[1] is None:
+            if cls.use_encode:
+                # In case of img2img, sample size of unet is determined by vae encoder.
+                vae_sample_size = model.vae.config.sample_size
+                if isinstance(vae_sample_size, int):
+                    sample_size = vae_sample_size // model.vae_scale_factor
+                else:
+                    sample_size = (
+                        vae_sample_size[0] // model.vae_scale_factor,
+                        vae_sample_size[1] // model.vae_scale_factor,
+                    )
+            else:
+                sample_size = model.unet.config.sample_size
+        else:
+            sample_size = (image_size[0] // model.vae_scale_factor, image_size[1] // model.vae_scale_factor)
+        return sample_size
+
+    @classmethod
+    def _get_default_config(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        # default configurations for each submodules
+        return {"img2img_pipeline": cls.use_encode}
+
+    @classmethod
+    def get_default_rbln_config_text_encoder(
+        cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        batch_size = rbln_config.get("batch_size", 1)
+        return {"batch_size": batch_size}
+
+    @classmethod
+    def get_default_rbln_config_text_encoder_2(
+        cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        batch_size = rbln_config.get("batch_size", 1)
+        return {"batch_size": batch_size}
+
+    @classmethod
+    def get_default_rbln_config_unet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        # configuration for unet
+        unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
+        text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
+        return {
+            **cls._get_default_config(model, rbln_config),
+            "max_seq_len": model.text_encoder.config.max_position_embeddings,
+            "text_model_hidden_size": text_model_hidden_size,
+            "batch_size": unet_batch_size,
+            "sample_size": cls._get_unet_sample_size(model, rbln_config),
+            "is_controlnet": "controlnet" in model.config.keys(),
+        }
+
+    @classmethod
+    def get_default_rbln_config_vae(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        # configuration for vae
+        batch_size = rbln_config.get("batch_size", 1)
+        return {
+            **cls._get_default_config(model, rbln_config),
+            "sample_size": cls._get_vae_sample_size(model, rbln_config),
+            "batch_size": batch_size,
+        }
+
+    @classmethod
+    def get_default_rbln_config_controlnet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        # configuration for controlnet
+        unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
+        text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
+        return {
+            **cls._get_default_config(model, rbln_config),
+            "max_seq_len": model.text_encoder.config.max_position_embeddings,
+            "vae_sample_size": cls._get_vae_sample_size(model, rbln_config),
+            "unet_sample_size": cls._get_unet_sample_size(model, rbln_config),
+            "batch_size": unet_batch_size,
+            "text_model_hidden_size": text_model_hidden_size,
+        }
+
+    @classmethod
+    def get_default_rbln_config(
+        cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # Returns the default configuration based on submodule name
+        config_method = f"get_default_rbln_config_{submodule_name}"
+        if hasattr(cls, config_method):
+            return getattr(cls, config_method)(model, rbln_config)
+        raise ValueError(f"Unknown submodule: {submodule_name}")
+
+    @staticmethod
+    def _maybe_apply_and_fuse_lora(
+        model: torch.nn.Module,
+        lora_ids: Optional[Union[str, List[str]]] = None,
+        lora_weights_names: Optional[Union[str, List[str]]] = None,
+        lora_scales: Optional[Union[float, List[float]]] = None,
+    ) -> torch.nn.Module:
+        lora_ids = [lora_ids] if isinstance(lora_ids, str) else lora_ids
+        lora_weights_names = [lora_weights_names] if isinstance(lora_weights_names, str) else lora_weights_names
+        lora_scales = [lora_scales] if isinstance(lora_scales, float) else lora_scales
+
+        # adapt lora weight into pipeline before compilation
+        if lora_ids and lora_weights_names:
+            if len(lora_ids) == 1:
+                if len(lora_ids) != len(lora_weights_names):
+                    raise ValueError(
+                        f"You must define the same number of lora ids ({len(lora_ids)} and lora weights ({len(lora_weights_names)}))"
+                    )
+                else:
+                    model.load_lora_weights(lora_ids[0], weight_name=lora_weights_names[0])
+                    model.fuse_lora(lora_scale=lora_scales[0] if lora_scales else 1.0)
+            elif len(lora_ids) > 1:
+                if not len(lora_ids) == len(lora_weights_names):
+                    raise ValueError(
+                        f"If you fuse {len(lora_ids)} lora models, but you must define the same number for lora weights and adapters."
+                    )
+
+                adapter_names = [f"adapter_{i}" for i in range(len(lora_ids))]
+
+                for lora_id, lora_weight, adapter_name in zip(lora_ids, lora_weights_names, adapter_names):
+                    model.load_lora_weights(lora_id, weight_name=lora_weight, adapter_name=adapter_name)
+
+                if lora_scales:
+                    model.set_adapters(adapter_names, adapter_weights=lora_scales)
+
+                model.fuse_lora()
+        return model
+
+    @classmethod
+    @use_rbln_config
+    def from_pretrained(
+        cls,
+        model_id: str,
+        *,
+        export: bool = False,
+        model_save_dir: Optional[PathLike] = None,
+        rbln_config: Dict[str, Any] = {},
+        lora_ids: Optional[Union[str, List[str]]] = None,
+        lora_weights_names: Optional[Union[str, List[str]]] = None,
+        lora_scales: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ) -> RBLNModel:
+        if export:
+            # keep submodules if user passed any of them.
+            passed_submodules = {
+                name: kwargs.pop(name) for name in cls._submodules if isinstance(kwargs.get(name), RBLNModel)
+            }
+
+        else:
+            # raise error if any of submodules are torch module.
+            for name in cls._submodules:
+                if isinstance(kwargs.get(name), torch.nn.Module):
+                    raise AssertionError(
+                        f"{name} is not compiled torch module. If you want to compile, set `export=True`."
+                    )
+
+        with ContextRblnConfig(
+            device=rbln_config.get("device"),
+            device_map=rbln_config.get("device_map"),
+            create_runtimes=rbln_config.get("create_runtimes"),
+            optimize_host_mem=rbln_config.get("optimize_host_memory"),
+        ):
+            model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
+
+        if not export:
+            return model
+
+        model = cls._maybe_apply_and_fuse_lora(
+            model,
+            lora_ids=lora_ids,
+            lora_weights_names=lora_weights_names,
+            lora_scales=lora_scales,
+        )
+
+        compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
+        return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)
+
+    @classmethod
+    def _compile_submodules(
+        cls,
+        model: torch.nn.Module,
+        passed_submodules: Dict[str, RBLNModel],
+        model_save_dir: Optional[PathLike],
+        rbln_config: Dict[str, Any],
+    ) -> Dict[str, RBLNModel]:
+        # Compile submodules based on rbln_config
+        compiled_submodules = {}
+
+        # FIXME : Currently, optimum-rbln for transformer does not use base rbln config.
+        base_rbln_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
+        for submodule_name in cls._submodules:
+            submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
+            submodule_rbln_config = cls.get_default_rbln_config(model, submodule_name, rbln_config)
+            submodule_rbln_config.update(base_rbln_config)
+            submodule_rbln_config.update(rbln_config.get(submodule_name, {}))
+
+            if submodule is None:
+                raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
+            elif isinstance(submodule, RBLNModel):
+                pass
+            elif submodule_name == "controlnet" and hasattr(submodule, "nets"):
+                # In case of multicontrolnet
+                submodule = cls._compile_multicontrolnet(
+                    controlnets=submodule,
+                    model_save_dir=model_save_dir,
+                    controlnet_rbln_config=submodule_rbln_config,
+                )
+            elif isinstance(submodule, torch.nn.Module):
+                submodule_cls: RBLNModel = getattr(
+                    importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
+                )
+                submodule = submodule_cls.from_model(
+                    model=submodule,
+                    subfolder=submodule_name,
+                    model_save_dir=model_save_dir,
+                    rbln_config=submodule_rbln_config,
+                )
+            else:
+                raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
+
+            compiled_submodules[submodule_name] = submodule
+        return compiled_submodules
+
+    @classmethod
+    def _compile_multicontrolnet(
+        cls,
+        controlnets: "MultiControlNetModel",
+        model_save_dir: Optional[PathLike],
+        controlnet_rbln_config: Dict[str, Any],
+    ):
+        # Compile multiple ControlNet models for a MultiControlNet setup
+        from .diffusers.models.controlnet import RBLNControlNetModel
+        from .diffusers.pipelines.controlnet import RBLNMultiControlNetModel
+
+        compiled_controlnets = [
+            RBLNControlNetModel.from_model(
+                model=controlnet,
+                subfolder="controlnet" if i == 0 else f"controlnet_{i}",
+                model_save_dir=model_save_dir,
+                rbln_config=controlnet_rbln_config,
+            )
+            for i, controlnet in enumerate(controlnets.nets)
+        ]
+        return RBLNMultiControlNetModel(compiled_controlnets, config=controlnets.nets[0].config)
+
+    @classmethod
+    def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
+        # Construct finalize pipe setup with compiled submodules and configurations
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            for submodule_name in cls._submodules:
+                delattr(model, submodule_name)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+            # FIXME: Here, model touches its submodules such as model.unet,
+            # Causing warning messeages.
+
+        update_dict = {}
+        for submodule_name in cls._submodules:
+            # replace submodule
+            setattr(model, submodule_name, submodules[submodule_name])
+            update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
+
+        # Update config to be able to load from model directory.
+        #
+        # e.g)
+        # update_dict = {
+        #     "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
+        #     "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
+        #     "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
+        # }
+        model.register_to_config(**update_dict)
+
+        if model_save_dir:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        if rbln_config.get("optimize_host_memory") is False:
+            # Keep compiled_model objs to further analysis. -> TODO: remove soon...
+            model.compiled_models = []
+            for name in cls._submodules:
+                submodule = getattr(model, name)
+                model.compiled_models.extend(submodule.compiled_models)
+
+        return model
+
+    @remove_compile_time_kwargs
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
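The mixin above is what the per-pipeline classes under optimum/rbln/diffusers/pipelines/ now delegate to, which is why each of those files shrinks by roughly 100-190 lines in this release. A minimal usage sketch, not taken from the diff: it assumes `RBLNStableDiffusionPipeline` is one of the mixin-based pipelines exported by `optimum.rbln`; the checkpoint id, save directory, and the per-submodule override value are placeholders, and the `rbln_config` keys are the ones read by the helper methods above.

```python
from optimum.rbln import RBLNStableDiffusionPipeline

# export=True compiles every entry in _submodules (text_encoder, unet, vae) for the NPU.
pipe = RBLNStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint id
    export=True,
    model_save_dir="sd15_rbln",        # placeholder save directory
    rbln_config={
        "img_height": 512,
        "img_width": 512,
        "guidance_scale": 7.5,      # > 1.0 doubles the UNet batch for classifier-free guidance
        "unet": {"batch_size": 2},  # per-submodule override, merged after the defaults
    },
)
```

`lora_ids`, `lora_weights_names`, and `lora_scales` can be passed alongside these arguments; `_maybe_apply_and_fuse_lora` fuses the adapters into the pipeline before the submodules are compiled.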
optimum/rbln/transformers/__init__.py
@@ -57,6 +57,7 @@ _import_structure = {
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNPhiForCausalLM",
+        "RBLNT5EncoderModel",
         "RBLNT5ForConditionalGeneration",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
@@ -97,6 +98,7 @@ if TYPE_CHECKING:
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
         RBLNQwen2ForCausalLM,
+        RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
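These two hunks export the new `RBLNT5EncoderModel` (backed by the changes to optimum/rbln/transformers/models/t5/modeling_t5.py listed above) from the package's import structure. A minimal sketch of what that enables, assuming the class follows the same `export`-style `from_pretrained` convention as the other RBLN model classes; the checkpoint id is a placeholder:

```python
from optimum.rbln import RBLNT5EncoderModel

# Hypothetical usage: compile only the T5 encoder stack for the NPU.
encoder = RBLNT5EncoderModel.from_pretrained("google-t5/t5-base", export=True)
```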
optimum/rbln/transformers/cache_utils.py
@@ -12,9 +12,11 @@ class RebelDynamicCache(DynamicCache):
     `[batch_size, num_heads, seq_len, head_dim]`.
     """
 
-    def __init__(self, current_steps) -> None:
+    def __init__(self, position_ids) -> None:
         super().__init__()
-        self.current_steps = current_steps
+        # batch, _ = position_ids.shape
+        # current_steps = [position_ids[b][0] for b in range(batch)]
+        self.current_steps = position_ids[:, 0]
 
     def assign(
         self,
@@ -58,13 +60,7 @@ class RebelDynamicCache(DynamicCache):
     @classmethod
     def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "DynamicCache":
         """Converts a cache in the rbln cache format (list of past_kv) into an equivalent `DynamicCache`."""
-
-        batch, _ = position_ids.shape
-        current_steps = [position_ids[b][0] for b in range(batch)]
-
-        assert len(current_steps) == batch
-        cache = cls(current_steps)
-
+        cache = cls(position_ids)
         for layer_idx in range(num_hidden_layer):
             key_states = past_key_values[layer_idx * 2]
             value_states = past_key_values[layer_idx * 2 + 1]
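The net effect of these two hunks is that `RebelDynamicCache` now derives each sequence's current decoding step with a single tensor slice instead of a per-batch Python loop plus an assert. A small illustration, not part of the package, of why the two forms agree for a [batch, seq_len] `position_ids` tensor:

```python
import torch

# Hypothetical position_ids of shape [batch, seq_len], as passed to from_input_format.
position_ids = torch.tensor([[3, 4, 5], [7, 8, 9]])

# Old form: Python loop collecting the first position of each row.
old_current_steps = [int(position_ids[b][0]) for b in range(position_ids.shape[0])]
# New form: one vectorized slice, stored directly by __init__.
new_current_steps = position_ids[:, 0]

assert old_current_steps == new_current_steps.tolist() == [3, 7]
```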