optimum_rbln-0.7.3a6-py3-none-any.whl → optimum_rbln-0.7.3.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/optimum/rbln/__version__.py
+++ b/optimum/rbln/__version__.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.7.3a6'
-__version_tuple__ = version_tuple = (0, 7, 3, 'a6')
+__version__ = version = '0.7.3.post1'
+__version_tuple__ = version_tuple = (0, 7, 3)
--- a/optimum/rbln/diffusers/modeling_diffusers.py
+++ b/optimum/rbln/diffusers/modeling_diffusers.py
@@ -23,7 +23,6 @@ from ..modeling import RBLNModel
 from ..modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
 from ..utils.decorator_utils import remove_compile_time_kwargs
 from ..utils.logging import get_logger
-from . import pipelines
 
 
 logger = get_logger(__name__)
@@ -67,6 +66,7 @@ class RBLNDiffusionMixin:
         as keys in rbln_config
     """
 
+    _connected_classes = {}
    _submodules = []
    _prefix = {}
 
@@ -103,37 +103,6 @@ class RBLNDiffusionMixin:
                 }
             )
             submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
-        elif hasattr(pipelines, submodule_class_name):
-            submodule_config = rbln_config.get(submodule_name, {})
-            submodule_config = copy.deepcopy(submodule_config)
-
-            submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"{submodule_class_name}")
-            prefix = cls._prefix.get(submodule_name, "")
-            connected_submodules = cls._connected_classes.get(submodule_name)._submodules
-            pipe_global_config = {k: v for k, v in submodule_config.items() if k not in connected_submodules}
-            submodule_config = {k: v for k, v in submodule_config.items() if k in connected_submodules}
-            for key in submodule_config.keys():
-                submodule_config[key].update(pipe_global_config)
-
-            for connected_submodule_name in connected_submodules:
-                connected_submodule_config = rbln_config.pop(prefix + connected_submodule_name, {})
-                if connected_submodule_name in submodule_config:
-                    submodule_config[connected_submodule_name].update(connected_submodule_config)
-                else:
-                    submodule_config[connected_submodule_name] = connected_submodule_config
-
-            pipe_global_config = {
-                k: v for k, v in rbln_config.items() if k != submodule_class_name and not isinstance(v, dict)
-            }
-
-            for connected_submodule_name in connected_submodules:
-                for k, v in pipe_global_config.items():
-                    if "guidance_scale" in k:
-                        if prefix + "guidance_scale" == k:
-                            submodule_config[connected_submodule_name]["guidance_scale"] = v
-                    else:
-                        submodule_config[connected_submodule_name][k] = v
-            rbln_config[submodule_name] = submodule_config
         else:
             raise ValueError(f"submodule {submodule_name} isn't supported")
         return submodule_config
@@ -199,25 +168,8 @@ class RBLNDiffusionMixin:
         else:
             # raise error if any of submodules are torch module.
             model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
-            if cls._load_connected_pipes:
-                submodules = []
-                for submodule in cls._submodules:
-                    submodule_config = rbln_config.pop(submodule, {})
-                    prefix = cls._prefix.get(submodule, "")
-                    connected_submodules = cls._connected_classes.get(submodule)._submodules
-                    for connected_submodule_name in connected_submodules:
-                        connected_submodule_config = submodule_config.pop(connected_submodule_name, {})
-                        if connected_submodule_config:
-                            rbln_config[prefix + connected_submodule_name] = connected_submodule_config
-                        submodules.append(prefix + connected_submodule_name)
-                pipe_global_config = {k: v for k, v in rbln_config.items() if k not in submodules}
-                for submodule in submodules:
-                    if submodule in rbln_config:
-                        rbln_config[submodule].update(pipe_global_config)
-            else:
-                submodules = cls._submodules
-
-            for submodule_name in submodules:
+            rbln_config = cls._flatten_rbln_config(rbln_config)
+            for submodule_name in cls._submodules:
                 if isinstance(kwargs.get(submodule_name), torch.nn.Module):
                     raise AssertionError(
                         f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
@@ -266,9 +218,89 @@ class RBLNDiffusionMixin:
             lora_scales=lora_scales,
         )
 
-        compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
+        if cls._load_connected_pipes:
+            compiled_submodules = cls._compile_pipelines(model, passed_submodules, model_save_dir, rbln_config)
+        else:
+            compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
         return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)
 
+    @classmethod
+    def _prepare_rbln_config(
+        cls,
+        rbln_config,
+    ) -> Dict[str, Any]:
+        prepared_config = {}
+        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
+            connected_pipe_config = rbln_config.pop(connected_pipe_name, {})
+            prefix = cls._prefix.get(connected_pipe_name, "")
+            guidance_scale = rbln_config.pop(f"{prefix}guidance_scale", None)
+            if "guidance_scale" not in connected_pipe_config and guidance_scale is not None:
+                connected_pipe_config["guidance_scale"] = guidance_scale
+            for submodule_name in connected_pipe_cls._submodules:
+                submodule_config = rbln_config.pop(prefix + submodule_name, {})
+                if submodule_name not in connected_pipe_config:
+                    connected_pipe_config[submodule_name] = {}
+                connected_pipe_config[submodule_name].update(
+                    {k: v for k, v in submodule_config.items() if k not in connected_pipe_config[submodule_name]}
+                )
+            prepared_config[connected_pipe_name] = connected_pipe_config
+        prepared_config.update(rbln_config)
+        return prepared_config
+
+    @classmethod
+    def _flatten_rbln_config(
+        cls,
+        rbln_config,
+    ) -> Dict[str, Any]:
+        prepared_config = cls._prepare_rbln_config(rbln_config)
+        flattened_config = {}
+        pipe_global_config = {k: v for k, v in prepared_config.items() if k not in cls._connected_classes.keys()}
+        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
+            connected_pipe_config = prepared_config.pop(connected_pipe_name)
+            prefix = cls._prefix.get(connected_pipe_name, "")
+            connected_pipe_global_config = {
+                k: v for k, v in connected_pipe_config.items() if k not in connected_pipe_cls._submodules
+            }
+            for submodule_name in connected_pipe_cls._submodules:
+                flattened_config[prefix + submodule_name] = connected_pipe_config[submodule_name]
+                flattened_config[prefix + submodule_name].update(
+                    {
+                        k: v
+                        for k, v in connected_pipe_global_config.items()
+                        if k not in flattened_config[prefix + submodule_name]
+                    }
+                )
+        flattened_config.update(pipe_global_config)
+        return flattened_config
+
+    @classmethod
+    def _compile_pipelines(
+        cls,
+        model: torch.nn.Module,
+        passed_submodules: Dict[str, RBLNModel],
+        model_save_dir: Optional[PathLike],
+        rbln_config: Dict[str, Any],
+    ) -> Dict[str, RBLNModel]:
+        compiled_submodules = {}
+
+        rbln_config = cls._prepare_rbln_config(rbln_config)
+        pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._connected_classes.keys()}
+        for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
+            connected_pipe_submodules = {}
+            prefix = cls._prefix.get(connected_pipe_name, "")
+            for submodule_name in connected_pipe_cls._submodules:
+                connected_pipe_submodules[submodule_name] = passed_submodules.get(prefix + submodule_name, None)
+            connected_pipe = getattr(model, connected_pipe_name)
+            connected_pipe_config = {}
+            connected_pipe_config.update(pipe_global_config)
+            connected_pipe_config.update(rbln_config[connected_pipe_name])
+            connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
+                connected_pipe, connected_pipe_submodules, model_save_dir, connected_pipe_config, prefix
+            )
+            for submodule_name, compiled_submodule in connected_pipe_compiled_submodules.items():
+                compiled_submodules[prefix + submodule_name] = compiled_submodule
+        return compiled_submodules
+
     @classmethod
     def _compile_submodules(
         cls,
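
The three methods added above replace the inline connected-pipe handling removed from the earlier hunks: _prepare_rbln_config normalizes nested per-pipe configs (and pulls prefixed guidance_scale keys inside), _flatten_rbln_config turns the result into prefixed top-level keys, and _compile_pipelines compiles each connected pipe with the merged config. A minimal standalone sketch of the nested-to-flat key transformation, assuming a Kandinsky-style pipe layout; all names and dict values below are illustrative, not taken from the package:

# Simplified sketch of _prepare_rbln_config/_flatten_rbln_config.
CONNECTED = {
    "prior_pipe": ["image_encoder", "text_encoder", "prior"],
    "decoder_pipe": ["unet", "movq"],
}
PREFIX = {"prior_pipe": "prior_"}

def flatten(rbln_config: dict) -> dict:
    flattened = {}
    top_level = {k: v for k, v in rbln_config.items() if k not in CONNECTED}
    for pipe_name, submodules in CONNECTED.items():
        pipe_cfg = rbln_config.get(pipe_name, {})
        prefix = PREFIX.get(pipe_name, "")
        # Pipe-level keys that are not submodules (e.g. guidance_scale)
        # are pushed down into every submodule config of that pipe.
        pipe_globals = {k: v for k, v in pipe_cfg.items() if k not in submodules}
        for name in submodules:
            sub_cfg = dict(pipe_cfg.get(name, {}))
            sub_cfg.update({k: v for k, v in pipe_globals.items() if k not in sub_cfg})
            flattened[prefix + name] = sub_cfg
    flattened.update(top_level)  # untouched global keys stay top-level
    return flattened

cfg = {"prior_pipe": {"guidance_scale": 4.0, "prior": {"device": 1}}, "img_width": 512}
print(flatten(cfg)["prior_prior"])  # {'device': 1, 'guidance_scale': 4.0}
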
@@ -307,41 +339,6 @@ class RBLNDiffusionMixin:
                     model_save_dir=model_save_dir,
                     rbln_config=submodule_rbln_config,
                 )
-            elif hasattr(pipelines, submodule.__class__.__name__):
-                connected_pipe = submodule
-                connected_pipe_model_save_dir = model_save_dir
-                connected_pipe_rbln_config = submodule_rbln_config
-                connected_pipe_cls: RBLNDiffusionMixin = getattr(
-                    importlib.import_module("optimum.rbln"), connected_pipe.__class__.__name__
-                )
-                submodule_dict = {}
-                for name in connected_pipe.config.keys():
-                    if hasattr(connected_pipe, name):
-                        submodule_dict[name] = getattr(connected_pipe, name)
-                connected_pipe = connected_pipe_cls(**submodule_dict)
-                connected_pipe_submodules = {}
-                prefix = cls._prefix.get(submodule_name, "")
-                for name in connected_pipe_cls._submodules:
-                    if prefix + name in passed_submodules:
-                        connected_pipe_submodules[name] = passed_submodules.get(prefix + name)
-
-                connected_pipe_compiled_submodules = connected_pipe_cls._compile_submodules(
-                    model=connected_pipe,
-                    passed_submodules=connected_pipe_submodules,
-                    model_save_dir=model_save_dir,
-                    rbln_config=connected_pipe_rbln_config,
-                    prefix=prefix,
-                )
-                connected_pipe = connected_pipe_cls._construct_pipe(
-                    connected_pipe,
-                    connected_pipe_compiled_submodules,
-                    connected_pipe_model_save_dir,
-                    connected_pipe_rbln_config,
-                )
-
-                for name in connected_pipe_cls._submodules:
-                    compiled_submodules[prefix + name] = getattr(connected_pipe, name)
-                submodule = connected_pipe
             else:
                 raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")
 
@@ -374,23 +371,16 @@ class RBLNDiffusionMixin:
     @classmethod
     def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
         # Construct finalize pipe setup with compiled submodules and configurations
-        submodule_names = []
-        for submodule_name in cls._submodules:
-            submodule = getattr(model, submodule_name)
-            if hasattr(pipelines, submodule.__class__.__name__):
-                prefix = cls._prefix.get(submodule_name, "")
-                connected_pipe_submodules = submodules[submodule_name].__class__._submodules
-                connected_pipe_submodules = [prefix + name for name in connected_pipe_submodules]
-                submodule_names += connected_pipe_submodules
-                setattr(model, submodule_name, submodules[submodule_name])
-            else:
-                submodule_names.append(submodule_name)
-
         if model_save_dir is not None:
             # To skip saving original pytorch modules
-            for submodule_name in submodule_names:
+            for submodule_name in cls._submodules:
                 delattr(model, submodule_name)
 
+            if cls._load_connected_pipes:
+                for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
+                    for submodule_name in connected_pipe_cls._submodules:
+                        delattr(getattr(model, connected_pipe_name), submodule_name)
+
             # Direct calling of `save_pretrained` causes config.unet = (None, None).
             # So config must be saved again, later.
             model.save_pretrained(model_save_dir)
@@ -398,10 +388,15 @@
         # Causing warning messeages.
 
         update_dict = {}
-        for submodule_name in submodule_names:
+        for submodule_name in cls._submodules:
             # replace submodule
             setattr(model, submodule_name, submodules[submodule_name])
             update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)
+        if cls._load_connected_pipes:
+            for connected_pipe_name, connected_pipe_cls in cls._connected_classes.items():
+                prefix = cls._prefix.get(connected_pipe_name, "")
+                for submodule_name in connected_pipe_cls._submodules:
+                    setattr(getattr(model, connected_pipe_name), submodule_name, submodules[prefix + submodule_name])
 
         # Update config to be able to load from model directory.
         #
@@ -420,16 +415,9 @@
         if rbln_config.get("optimize_host_memory") is False:
             # Keep compiled_model objs to further analysis. -> TODO: remove soon...
             model.compiled_models = []
-            if model._load_connected_pipes:
-                for name in cls._submodules:
-                    connected_pipe = getattr(model, name)
-                    for submodule_name in connected_pipe.__class__._submodules:
-                        submodule = getattr(connected_pipe, submodule_name)
-                        model.compiled_models.extend(submodule.compiled_models)
-            else:
-                for name in cls._submodules:
-                    submodule = getattr(model, name)
-                    model.compiled_models.extend(submodule.compiled_models)
+            for name in cls._submodules:
+                submodule = getattr(model, name)
+                model.compiled_models.extend(submodule.compiled_models)
 
         return model
 
--- a/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -39,7 +39,7 @@ from .pipeline_kandinsky2_2_prior import RBLNKandinskyV22PriorPipeline
 class RBLNKandinskyV22CombinedPipeline(RBLNDiffusionMixin, KandinskyV22CombinedPipeline):
     original_class = KandinskyV22CombinedPipeline
     _connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22Pipeline}
-    _submodules = ["prior_pipe", "decoder_pipe"]
+    _submodules = ["prior_image_encoder", "prior_text_encoder", "prior_prior", "unet", "movq"]
     _prefix = {"prior_pipe": "prior_"}
 
     def __init__(
@@ -90,7 +90,7 @@ class RBLNKandinskyV22CombinedPipeline(RBLNDiffusionMixin, KandinskyV22CombinedP
 class RBLNKandinskyV22Img2ImgCombinedPipeline(RBLNDiffusionMixin, KandinskyV22Img2ImgCombinedPipeline):
     original_class = KandinskyV22Img2ImgCombinedPipeline
     _connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22Img2ImgPipeline}
-    _submodules = ["prior_pipe", "decoder_pipe"]
+    _submodules = ["prior_image_encoder", "prior_text_encoder", "prior_prior", "unet", "movq"]
     _prefix = {"prior_pipe": "prior_"}
 
     def __init__(
@@ -141,7 +141,7 @@ class RBLNKandinskyV22Img2ImgCombinedPipeline(RBLNDiffusionMixin, KandinskyV22Im
 class RBLNKandinskyV22InpaintCombinedPipeline(RBLNDiffusionMixin, KandinskyV22InpaintCombinedPipeline):
     original_class = KandinskyV22InpaintCombinedPipeline
     _connected_classes = {"prior_pipe": RBLNKandinskyV22PriorPipeline, "decoder_pipe": RBLNKandinskyV22InpaintPipeline}
-    _submodules = ["prior_pipe", "decoder_pipe"]
+    _submodules = ["prior_image_encoder", "prior_text_encoder", "prior_prior", "unet", "movq"]
     _prefix = {"prior_pipe": "prior_"}
 
     def __init__(
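
With _submodules now listing the prefixed leaf modules instead of the two sub-pipes, per-submodule options address the flattened names directly, while nested per-pipe dicts are still accepted via _prepare_rbln_config. A usage sketch; the model id and option values are placeholders, and the exact set of accepted per-submodule keys depends on the release:

from optimum.rbln import RBLNKandinskyV22CombinedPipeline

pipe = RBLNKandinskyV22CombinedPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder",  # placeholder model id
    export=True,
    rbln_config={
        # nested form, normalized by _prepare_rbln_config ...
        "prior_pipe": {"guidance_scale": 4.0},
        # ... or the flattened form matching the new _submodules names
        "prior_image_encoder": {"device": 0},
    },
)
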
--- a/optimum/rbln/transformers/models/bart/modeling_bart.py
+++ b/optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -108,6 +108,8 @@ class RBLNBartModel(RBLNModel):
 
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    support_paged_causal_attn = True
+
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = (
--- a/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
+++ b/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -98,6 +98,18 @@ def validate_attention_method(
             "this requirement, or consider switching `rbln_attn_impl` to 'eager' for shorter lengths."
         )
 
+    if rbln_kvcache_block_size is not None:
+        if rbln_attn_impl == "flash_attn" and rbln_kvcache_partition_len != rbln_kvcache_block_size:
+            raise ValueError(
+                f" When using 'flash attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
+                f"must always be set equal to the `rbln_kvcache_partition_len` {rbln_kvcache_partition_len}."
+            )
+        elif rbln_attn_impl == "eager" and rbln_kvcache_block_size != rbln_max_seq_len:
+            raise ValueError(
+                f" When using 'eager attention', the `rbln_kvcache_block_size` ({rbln_kvcache_block_size}) "
+                f"must always be set equal to the `rbln_max_seq_len` {rbln_max_seq_len}."
+            )
+
     return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size
 
 
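The added branch couples rbln_kvcache_block_size to the attention implementation: flash attention requires it to equal rbln_kvcache_partition_len, and eager attention requires it to equal rbln_max_seq_len. A condensed restatement of the rule; the function name and signature below are illustrative, not the package's:

def check_kvcache_block_size(attn_impl, block_size, partition_len, max_seq_len):
    if block_size is None:
        return
    if attn_impl == "flash_attn" and block_size != partition_len:
        raise ValueError("flash_attn requires rbln_kvcache_block_size == rbln_kvcache_partition_len")
    if attn_impl == "eager" and block_size != max_seq_len:
        raise ValueError("eager requires rbln_kvcache_block_size == rbln_max_seq_len")

check_kvcache_block_size("flash_attn", 4096, 4096, 32768)  # passes
check_kvcache_block_size("eager", 8192, None, 8192)        # passes
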
--- a/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
+++ b/optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -50,6 +50,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         runtime: rebel.Runtime,
         batch_size: int,
         dec_max_seq_len: int,
+        support_paged_causal_attn: Optional[bool] = None,
         use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
@@ -57,7 +58,10 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         self.batch_size = batch_size
         self.dec_max_seq_len = dec_max_seq_len
         self.use_attention_mask = use_attention_mask
-        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+        if support_paged_causal_attn:
+            self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+        else:
+            self.default_block_tables = None
 
     def forward(
         self,
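
When paged causal attention is supported, each decoder batch slot owns exactly one KV-cache block, so the default table is just a column of batch indices:

import torch

# Default block table for batch_size=4, as built above.
print(torch.arange(0, 4, dtype=torch.int16).view(4, 1))
# tensor([[0],
#         [1],
#         [2],
#         [3]], dtype=torch.int16)
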
@@ -94,7 +98,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
             decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
-            block_tables,
+            block_tables=block_tables,
         )
 
         return Seq2SeqLMOutput(logits=lm_logits)
@@ -115,6 +119,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
+    support_paged_causal_attn = None
 
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -130,6 +135,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
             main_input_name="input_ids",
             batch_size=batch_size,
             dec_max_seq_len=dec_max_seq_len,
+            support_paged_causal_attn=self.support_paged_causal_attn,
             use_attention_mask=self.use_attention_mask,
         )
 
@@ -186,13 +192,16 @@
         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
 
-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-            if rbln_npu == "RBLN-CA02":
-                rbln_use_attention_mask = True
+        if cls.support_paged_causal_attn:
+            rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+            if rbln_use_attention_mask is None:
+                rbln_use_attention_mask = False
+                rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+                if rbln_npu == "RBLN-CA02":
+                    rbln_use_attention_mask = True
+        else:
+            rbln_use_attention_mask = True
 
         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
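
The reworked branch makes the decoder attention mask optional only for models that opt into paged causal attention; models that do not (now including T5) always compile with the mask, and RBLN-CA02 devices keep it even when paging is supported. The decision reduces to the following sketch; names are illustrative:

def resolve_use_attention_mask(support_paged_causal_attn, user_choice, npu_name):
    # npu_name stands in for rbln_kwargs.get("npu") or rebel.get_npu_name().
    if not support_paged_causal_attn:
        return True                 # e.g. RBLNT5ForConditionalGeneration
    if user_choice is not None:
        return user_choice          # an explicit use_attention_mask wins
    return npu_name == "RBLN-CA02"  # this NPU generation keeps the mask

assert resolve_use_attention_mask(True, None, "RBLN-CA02") is True
assert resolve_use_attention_mask(True, None, "RBLN-CA12") is False
assert resolve_use_attention_mask(False, None, "RBLN-CA12") is True
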
@@ -265,11 +274,6 @@
                 [rbln_batch_size, 1],
                 "int32",
             ),
-            (
-                "block_tables",
-                [rbln_batch_size, 1],
-                "int16",
-            ),
         ]
         dec_input_info.extend(
             [
@@ -302,6 +306,8 @@
             ]
         )
 
+        if cls.support_paged_causal_attn:
+            dec_input_info.insert(3, ("block_tables", [rbln_batch_size, 1], "int16"))
         if rbln_use_attention_mask:
             dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
 
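Because insert(3, ...) runs before insert(1, ...), block_tables lands right after cache_position and attention_mask lands right after input_ids. A quick check of the resulting decoder input layout, assuming the base order visible in the hunks above:

dec_inputs = ["input_ids", "encoder_attention_mask", "cache_position", "kv_states"]
dec_inputs.insert(3, "block_tables")    # cls.support_paged_causal_attn
dec_inputs.insert(1, "attention_mask")  # rbln_use_attention_mask
print(dec_inputs)
# ['input_ids', 'attention_mask', 'encoder_attention_mask',
#  'cache_position', 'block_tables', 'kv_states']
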
--- a/optimum/rbln/transformers/models/t5/modeling_t5.py
+++ b/optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
-import rebel
 import torch
 from transformers import (
     AutoModelForTextEncoding,
@@ -23,7 +22,7 @@ from transformers import (
     T5EncoderModel,
     T5ForConditionalGeneration,
 )
-from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.modeling_outputs import BaseModelOutput
 
 from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
 from ....modeling import RBLNModel
@@ -58,63 +57,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
 
-class RBLNRuntimeEncoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        _ = super().forward(*args, **kwargs)
-        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
-
-
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    def __init__(
-        self,
-        runtime: rebel.Runtime,
-        batch_size: int,
-        dec_max_seq_len: int,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(runtime, **kwargs)
-        self.batch_size = batch_size
-        self.dec_max_seq_len = dec_max_seq_len
-
-    def forward(
-        self,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        decoder_attention_mask: Optional[torch.BoolTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        batch_size = decoder_input_ids.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        for b_idx in range(self.batch_size):
-            decoding_step = cache_position[b_idx].item()
-            if not (0 <= decoding_step < self.dec_max_seq_len):
-                raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
-
-        lm_logits = super().forward(
-            decoder_input_ids,
-            decoder_attention_mask,
-            attention_mask,
-            cache_position,
-        )
-
-        return Seq2SeqLMOutput(logits=lm_logits)
-
-
 class T5EncoderWrapper(torch.nn.Module):
     def __init__(self, model: "T5EncoderModel") -> None:
         super().__init__()
@@ -247,20 +189,7 @@ class RBLNT5EncoderModel(RBLNModel):
 
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
-    def __post_init__(self, **kwargs):
-        batch_size = self.rbln_config.model_cfg["batch_size"]
-        dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-
-        self.encoder = RBLNRuntimeEncoder(
-            runtime=self.model[0],
-            main_input_name="input_ids",
-        )
-        self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1],
-            main_input_name="input_ids",
-            batch_size=batch_size,
-            dec_max_seq_len=dec_max_seq_len,
-        )
+    support_causal_paged_attn = False
 
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
@@ -279,139 +208,3 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
             return redirect(val)
 
         return val
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
-        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
-        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
-        d_kv = (
-            model_config.d_kv
-            if hasattr(model_config, "d_kv")
-            else model_config.d_model // model_config.encoder_attention_heads
-        )
-
-        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
-            model_config, "max_position_embeddings", None
-        )
-
-        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
-            if rbln_pad_token_id is None:
-                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-                if rbln_pad_token_id is None:
-                    rbln_pad_token_id = -1
-
-        if rbln_enc_max_seq_len is None:
-            rbln_enc_max_seq_len = max_position_embeddings
-            if rbln_enc_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_enc_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_enc_max_seq_len is None:
-                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        if rbln_dec_max_seq_len is None:
-            rbln_dec_max_seq_len = max_position_embeddings
-            if rbln_dec_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_dec_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_dec_max_seq_len is None:
-                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        # model input info
-        enc_input_info = [
-            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
-            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
-            (
-                "cross_key_value_states",
-                [
-                    n_layer * 2,
-                    rbln_batch_size,
-                    n_head,
-                    rbln_enc_max_seq_len,
-                    d_kv,
-                ],
-                "float32",
-            ),
-            ("block_tables", [1], "int16"),
-        ]
-
-        dec_input_info = [
-            ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
-            (
-                "cache_position",
-                [rbln_batch_size, 1],
-                "int32",
-            ),
-        ]
-        dec_input_info.extend(
-            [
-                (
-                    "cross_key_value_states",
-                    [
-                        n_layer * 2,
-                        rbln_batch_size,
-                        n_head,
-                        rbln_enc_max_seq_len,
-                        d_kv,
-                    ],
-                    "float32",
-                )
-            ]
-        )
-        dec_input_info.extend(
-            [
-                (
-                    f"self_key_value_states_{i}",
-                    [
-                        rbln_batch_size,
-                        n_head,
-                        rbln_dec_max_seq_len,
-                        d_kv,
-                    ],
-                    "float32",
-                )
-                for i in range(n_layer * 2)
-            ]
-        )
-
-        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
-        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "enc_max_seq_len": rbln_enc_max_seq_len,
-                "dec_max_seq_len": rbln_dec_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "pad_token_id": rbln_pad_token_id,
-            }
-        )
-
-        return rbln_config
--- a/optimum/rbln/utils/import_utils.py
+++ b/optimum/rbln/utils/import_utils.py
@@ -28,6 +28,13 @@ class VersionCompat:
 
 
 RBLN_VERSION_COMPATS = {
+    "0.7.3": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.7.3",
+            max_version="0.7.4",
+        ),
+    ],
     "0.7.2": [
         VersionCompat(
             package_name="rebel-compiler",
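
The new table entry pins 0.7.3-series releases of optimum-rbln to rebel-compiler >= 0.7.3, < 0.7.4. A minimal sketch of the kind of check such an entry enables; the helper below is illustrative and not the package's actual implementation:

from dataclasses import dataclass
from importlib import metadata

from packaging.version import Version

@dataclass
class VersionCompat:
    package_name: str
    min_version: str
    max_version: str

def is_compatible(compat: VersionCompat) -> bool:
    # Assumes a half-open range [min_version, max_version).
    installed = Version(metadata.version(compat.package_name))
    return Version(compat.min_version) <= installed < Version(compat.max_version)

try:
    print(is_compatible(VersionCompat("rebel-compiler", "0.7.3", "0.7.4")))
except metadata.PackageNotFoundError:
    print("rebel-compiler is not installed")
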
--- a/optimum_rbln-0.7.3a6.dist-info/METADATA
+++ b/optimum_rbln-0.7.3.post1.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.7.3a6
+Version: 0.7.3.post1
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
--- a/optimum_rbln-0.7.3a6.dist-info/RECORD
+++ b/optimum_rbln-0.7.3.post1.dist-info/RECORD
@@ -1,10 +1,10 @@
 optimum/rbln/__init__.py,sha256=ZDzXcl-oAcYJhKjJMpotjbTih9awo7HzUb6T3MUEP6Q,6894
-optimum/rbln/__version__.py,sha256=9voT1MrnPHKvqTeiZK8bNEZcPseZOq7N_U5etptnmTE,519
+optimum/rbln/__version__.py,sha256=aegWGVZeZJ9bIegWWNAgPL2y9SAs5kPTsXCQi0EZ9go,517
 optimum/rbln/modeling.py,sha256=nJsAs5zs--VVOYGFjYNpqfxYIemJIK4Lr0WEzlDLdP0,8390
 optimum/rbln/modeling_base.py,sha256=dNCL-BhrWCpuOVkZaj8-MW567Tf4lLo3p3Z3ldjWJfU,21779
 optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
 optimum/rbln/diffusers/__init__.py,sha256=Hq87CbtiCy85YmK2SB-OmUyfv77oe3j4bsTenTRnu6w,3623
-optimum/rbln/diffusers/modeling_diffusers.py,sha256=zqVNgH9oeOx2iNE7VsW_FinVf4s6G5Idyh4TKz7XJJg,21116
+optimum/rbln/diffusers/modeling_diffusers.py,sha256=IS6Mlgexofap7f9Lefk5cKFP7ejSG_oWN3v2PX9_IDQ,20118
 optimum/rbln/diffusers/models/__init__.py,sha256=mkCvJyH1KcwrsUvYSq_bVC79oOfyqtBSFDyPS1_48wA,1478
 optimum/rbln/diffusers/models/controlnet.py,sha256=EM_HlzCdaZdnnK0oGpY2fQeigPqHhlwh4NHCzlmoumI,10512
 optimum/rbln/diffusers/models/autoencoders/__init__.py,sha256=dg17ZTUsiqTcbIaEE4fqew9uRbao0diQ21PXvRKIqKg,679
@@ -25,7 +25,7 @@ optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py,sha256=
 optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py,sha256=RfwxNX_zQWFtvvFQJ5bt3qtHbdYdQV_3XLHm9WYCKOs,46084
 optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py,sha256=I4YQq2HfA3xONbWsdJ870IEJPyLWeCDDG-UCJsu9YO8,1035
 optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py,sha256=aNFGOjth8tDvPrjYLbRWrkHr6p-8AFgcQx1Qay1fw70,904
-optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py,sha256=unqFDviA7dnx0yuo8L8tXVj2mjFYCPm7C9dcpdWBICc,6882
+optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py,sha256=BVXOpdrezWVTCibpuAMu9KkD5oEQUY00cSqm6dFbTnk,7020
 optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py,sha256=fEs-WgJqWs5zvuCkKb7MuZokH9Mi6q-0DOEKxzfWxzo,932
 optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py,sha256=Ad2ZYCXaMiYpB0mz-8X1CGhILxrVbt7rRIXt6IPwYBM,932
 optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py,sha256=Mf7tzrXetwCgt7LuXfkX-CX1hltLgNZdwF9bHxAbDJM,874
@@ -55,13 +55,13 @@ optimum/rbln/transformers/models/auto/auto_factory.py,sha256=IK9jFrJ3EEzYQa9_aKp
 optimum/rbln/transformers/models/auto/modeling_auto.py,sha256=Un9qoqdy3dO8JBza_bTJF_6_fRVNM9QisihSgTRFI-o,3933
 optimum/rbln/transformers/models/bart/__init__.py,sha256=32HPe0_GIO0hp9U464Iv6Jd7M-1nop9g8hA1UZMHhyw,674
 optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=Oo-Cdne7igKEex8wwP-gztKJHgs5GLHQjK1oc3IZIDE,5801
-optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=iI3ubPOVvHmhLt0wEz_vkOfMyNTHVNjmnkLtbpOX760,5797
+optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=6IpWXlBCd02v66KF77oEWfrv8-FnPBYjjjL_8KZL3Ow,5835
 optimum/rbln/transformers/models/bert/__init__.py,sha256=YVV7k_laU6yJBawZrgjIWjRmIF-Y4oQQHqyf8lsraQs,691
 optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=p3utRqf3dv9_RkHwaMCa1EfXttNJkqCJUIZo3CeZ9YY,4674
 optimum/rbln/transformers/models/clip/__init__.py,sha256=H9vuBwrmFO0-CqZhXUrKF-uQL6igCqMlqrT1X_ELaAI,754
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=NiSm7bHs4SReHDUr53BBWSX0Y8bkKOeUSpsBDrp8YDw,6628
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
-optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=7OIKteJLKNxOLOg0w3lLOM7TxZovQn4jkglI9wRkrtQ,40609
+optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=m93-qKN7NMw3i0XDmFmttmRIRK4np_fWtLFlBb2RFgU,41351
 optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=uGdPGcFrWm2gAwFLjfBiALwFsl49VGCReVi4NUfOPxM,38898
 optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
 optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
@@ -92,10 +92,10 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
 optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
 optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
 optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
-optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=NPfJf9Uk_bYOae7hXGHwteGiWH0va63Z-D93RmAMENg,17611
+optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=9Pf9Y86ABDfhwIenlZqYfgqjbyFmtKBiPnbCD7zxw4M,18017
 optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=tvzacIZam1sIr_1BvvZ_fDr8u5dXAiYiynFdX9tArtY,18877
 optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
-optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=nKRR3eH1EAu1YkKvhlqGyTrJXDRd-IWB5LOeG9jrcb4,16021
+optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=8PAhPlYT1dmpcWM7hUMmZV9lPd4d75CuMuFen1pzr3Q,8088
 optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=AArCQhZRETVM583wlIRzMFOSYq7t2nzxaAeyhZxyxKk,9508
 optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
 optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
@@ -110,13 +110,13 @@ optimum/rbln/transformers/utils/rbln_quantization.py,sha256=gwBVHf97sQgPNmGa0wq8
 optimum/rbln/utils/__init__.py,sha256=ieDBT2VFTt2E0M4v_POLBpuGW9LxSydpb_DuPd6PQqc,712
 optimum/rbln/utils/decorator_utils.py,sha256=xu-TrsNi33SRC2a7DBsyoo6-pEQxWKZPZSmM9QlDe2Y,3745
 optimum/rbln/utils/hub.py,sha256=bNmOJGEO9Jfux4Cg8Xli-898I4mxk20KuwQOhP0Zs1U,4198
-optimum/rbln/utils/import_utils.py,sha256=n4HcvZPzFW2ytl45qJ4ZQYlrRSoOb0-nnqhyT2_JA8M,4224
+optimum/rbln/utils/import_utils.py,sha256=uMldLJmDVMj5uHvxBfb96uV29bfGEDvlksLY26GOHAs,4389
 optimum/rbln/utils/logging.py,sha256=VKKBmlQSdg6iZCGmAXaWYiW67K84jyp1QJhLQSSjPPE,3453
 optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvVxDvs,1278
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.7.3a6.dist-info/METADATA,sha256=TGw8TCIfBQ9RWlzxf5JI16Zoy-xoEodnBO8m6SKXBsk,5300
-optimum_rbln-0.7.3a6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-optimum_rbln-0.7.3a6.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-optimum_rbln-0.7.3a6.dist-info/RECORD,,
+optimum_rbln-0.7.3.post1.dist-info/METADATA,sha256=dKER74SsqGQwVQgTXVM854y97xzhfRl5LKaGedd4IIw,5304
+optimum_rbln-0.7.3.post1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.7.3.post1.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.7.3.post1.dist-info/RECORD,,