optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
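The bulk of this release is a refactor: the per-model compile and runtime code that Llama, GPT-2, and Midm each carried in 0.1.4 moves into a shared decoder-only stack (optimum/rbln/transformers/models/decoderonly/), and Gemma, DPT, and XLM-RoBERTa support is layered on top of it, as the line counts above suggest. The user-facing entry point is unchanged. A minimal usage sketch follows, assuming the optimum-style from_pretrained(..., export=True) API and the rbln_* compile options that appear in the diff below; exact option names may differ by version.

# Minimal sketch (not taken from this package's docs): compile a Llama checkpoint
# for an RBLN NPU and run generation with it.
from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,            # trace and compile the PyTorch model to a .rbln artifact
    rbln_max_seq_len=4096,  # static sequence length baked into the compiled graph
    rbln_batch_size=1,
)
model.save_pretrained("llama-2-7b-rbln")  # reload later without recompiling

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))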
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -21,61 +21,23 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import inspect # noqa: I001
+ import inspect
  import logging
- from pathlib import Path
- from tempfile import TemporaryDirectory
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable

- import torch # noqa: F401
- import rebel # noqa: F401
+ from transformers import LlamaForCausalLM, PreTrainedModel

- from optimum.exporters import TasksManager
- from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
- from transformers.modeling_outputs import CausalLMOutputWithPast
-
- from ...generation.utils import RBLNGenerationMixin
- from ....modeling_base import RBLNBaseModel
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
- from ....utils.runtime_utils import RBLNPytorchRuntime
- from ....utils.save_utils import maybe_save_preprocessors
-
-
- # FIXME:: Merge Two architecture Codes
- from .llama_architecture import (
-     LlamaWrapper,
-     wrap_llama,
-     unwrap_llama,
- )
-
- from .llama_architecture_cb import (
-     LlamaDynamicBatchWrapper as LlamaWrapper_cb,
-     wrap_llama as wrap_llama_cb,
- )
+ from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .llama_architecture import LlamaWrapper


  logger = logging.getLogger(__name__)

- if TYPE_CHECKING:
-     from transformers import (
-         AutoFeatureExtractor,
-         AutoProcessor,
-         AutoTokenizer,
-         PretrainedConfig,
-     )
-

- SUPPORTED_BATCHING_MODES = ["static", "vllm"]
-
-
- class RBLNRuntimeModel(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name"]
-
-
- class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
+ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Llama Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

      A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -83,273 +45,9 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
      - compiling the resulting graph using the RBLN compiler.
      """

-     model_type = "rbln_model"
-     main_input_name = "input_ids"
-     auto_model_class = AutoModelForCausalLM
-
-     def __post_init__(self, **kwargs):
-         self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-         self.use_continuous_batch = self.rbln_config.meta["rbln_batching"] == "vllm"
-
-         prefill_batch_size = self.batch_size if not self.use_continuous_batch else 1
-         self.prefill_attention_mask = torch.zeros(
-             prefill_batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-         )
-         self.causal_mask = 1 - torch.triu(
-             torch.ones(prefill_batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-         )
-         self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-
-         self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
-         self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
-         self.past_cached_length = 0
-         self.right_padding = True
-
-     @classmethod
-     @torch.no_grad()
-     def _export(
-         cls,
-         model_id: str,
-         config: "PretrainedConfig",
-         use_auth_token: Optional[Union[bool, str]] = None,
-         revision: Optional[str] = None,
-         force_download: bool = False,
-         cache_dir: Optional[str] = None,
-         subfolder: str = "",
-         local_files_only: bool = False,
-         trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-         **kwargs,
-     ) -> "RBLNLlamaForCausalLM":
-         task = kwargs.pop("task", None)
-         if task is None:
-             task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-         if model_save_dir is None:
-             save_dir = TemporaryDirectory()
-             save_dir_path = Path(save_dir.name)
-         else:
-             save_dir = model_save_dir
-             if isinstance(save_dir, TemporaryDirectory):
-                 save_dir_path = Path(model_save_dir.name)
-             else:
-                 save_dir_path = Path(model_save_dir)
-                 save_dir_path.mkdir(exist_ok=True)
-
-         def update_configs(kwargs):
-             hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
-             max_seq_len = kwargs.get("rbln_max_seq_len", None)
-             if max_seq_len is not None:
-                 if max_seq_len <= hf_max_position_embeddings:
-                     kwargs.update({"max_position_embeddings": max_seq_len})
-                 else:
-                     raise ValueError("`max_seq_len` should be less or equal than max_position_embeddings!")
-
-             kwargs.update(
-                 {
-                     "torchscript": True,
-                     "return_dict": False,
-                     "use_cache": True,
-                     "torch_dtype": torch.float32,
-                     "_attn_implementation": "eager",
-                 }
-             )
-
-             return kwargs
-
-         kwargs = update_configs(kwargs)
-
-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-         # FIXME :: This should be moved when wrapping removed.
-         use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
-         origin_mehtods = wrap_llama_cb() if use_continuous_batch else wrap_llama()
-
-         model: LlamaForCausalLM = TasksManager.get_model_from_task(
-             task=task,
-             model_name_or_path=model_id,
-             subfolder=subfolder,
-             revision=revision,
-             framework="pt",
-             cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
-             local_files_only=local_files_only,
-             force_download=force_download,
-             trust_remote_code=trust_remote_code,
-             **kwargs,
-         )
-
-         if config is None:
-             config = model.config
-
-         config.save_pretrained(save_dir_path)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-         # Get compilation arguments
-         if rbln_config_kwargs.get("rbln_config", None) is None:
-             rbln_config = cls.get_rbln_config(
-                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-             )
-
-         def compile_llama(use_continuous_batch, wrapper_cls):
-             wrapped_model = wrapper_cls(model).eval()
-
-             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
-
-             if use_continuous_batch:
-                 batch_index_index = 3
-                 dec_example_inputs[batch_index_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.
-
-             prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs)
-             dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs)
-
-             prefill_ir = rebel.torchscript_to_ir(
-                 prefill_scripted_model,
-                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-             )
-             dec_ir = rebel.torchscript_to_ir(
-                 dec_scripted_model,
-                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-             )
-
-             # Caching prefill_decoder/decoder I/O
-             cache_index_offset = 4 if use_continuous_batch else 3
-             connections = [
-                 (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-                 for i in range(model.config.num_hidden_layers * 2)
-             ]
-
-             compiled_model = rebel.compile(
-                 prefill_ir,
-                 dec_ir,
-                 connections=connections,
-                 fusion=prefill_rbln_runtime_config.fusion,
-                 npu=prefill_rbln_runtime_config.npu,
-                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                 use_weight_sharing=True,
-             )
-             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-         wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
-         compile_llama(use_continuous_batch=use_continuous_batch, wrapper_cls=wrapper_cls)
-         unwrap_llama(origin_mehtods)
-
-         rbln_config.save(save_dir_path)
-
-         return cls._from_pretrained(
-             model_id=save_dir_path,
-             config=config,
-             model_save_dir=save_dir,
-             **rbln_constructor_kwargs,
-             **kwargs,
-         )
-
      @classmethod
-     def _get_rbln_config(
-         cls,
-         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-         model_config: "PretrainedConfig",
-         rbln_max_seq_len: Optional[int] = None,
-         rbln_batch_size: Optional[int] = None,
-         rbln_batching: Optional[str] = None,
-     ) -> RBLNConfig:
-         meta = {}
-
-         prefill_chunk_size = 128
-         if rbln_max_seq_len is None:
-             rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
-         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-         rbln_batching = "static" if rbln_batching is None else rbln_batching
-
-         meta["rbln_max_seq_len"] = rbln_max_seq_len
-         meta["rbln_batch_size"] = rbln_batch_size
-         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
-         meta["rbln_batching"] = rbln_batching
-         use_continuous_batching = meta["rbln_batching"] == "vllm"
-
-         if rbln_batching not in SUPPORTED_BATCHING_MODES:
-             raise ValueError(
-                 f'rbln_batching="{rbln_batching}" is not a supported batch mode, '
-                 f"Possible: {SUPPORTED_BATCHING_MODES}"
-             )
-
-         def get_input_info(
-             batch_size, # should be 1 if continous batch prefill
-             query_length,
-             continuous_batch=False, # determines the shape of `cache position`
-         ):
-             input_info = [
-                 ("input_ids", [batch_size, query_length], "int64"),
-                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                 (
-                     "cache_position",
-                     [batch_size, query_length] if continuous_batch else [],
-                     "int32",
-                 ),
-             ]
-
-             if continuous_batch:
-                 input_info.append(("batch_position", [], "int16"))
-
-             input_info.extend(
-                 [
-                     (
-                         f"past_key_values_{i}",
-                         [
-                             rbln_batch_size,
-                             model_config.num_key_value_heads,
-                             rbln_max_seq_len,
-                             model_config.hidden_size // model_config.num_attention_heads,
-                         ],
-                         "float32",
-                     )
-                     for i in range(model_config.num_hidden_layers * 2)
-                 ]
-             )
-
-             return input_info
-
-         prefill_input_info = get_input_info(
-             batch_size=1 if use_continuous_batching else rbln_batch_size,
-             query_length=prefill_chunk_size,
-             continuous_batch=use_continuous_batching,
-         )
-         dec_input_info = get_input_info(
-             batch_size=rbln_batch_size,
-             query_length=1,
-             continuous_batch=use_continuous_batching,
-         )
-
-         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-         dec_rbln_runtime_config.batch_size = rbln_batch_size
-
-         rbln_config = RBLNConfig.from_rbln_runtime_configs(
-             [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-             _rbln_meta=meta,
-         )
-
-         return rbln_config
-
-     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-         return [
-             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-         ]
-
-     def get_decoder(self):
-         return self.decoder
-
-     def can_generate(self):
-         return True
+     def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+         return LlamaWrapper(model, rbln_max_seq_len).eval()

      def __getattr__(self, __name: str) -> Any:
          def redirect(func):
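For reference, the removed _get_rbln_config above is what pinned down the compiled graphs' static I/O layout: a prefill graph chunked at 128 tokens and a single-token decode graph that share the same KV-cache buffers. A worked sketch of the static-batching prefill input list it produced, using made-up small config values rather than a real Llama config:

# Worked example of the input_info layout built by the removed get_input_info helper.
# All config numbers here are hypothetical; a real model_config supplies
# num_hidden_layers, num_key_value_heads, num_attention_heads and hidden_size.
num_hidden_layers, num_kv_heads, num_attention_heads, hidden_size = 2, 4, 4, 64
rbln_max_seq_len, rbln_batch_size, prefill_chunk_size = 128, 1, 128

prefill_input_info = [
    ("input_ids", [rbln_batch_size, prefill_chunk_size], "int64"),
    ("attention_mask", [rbln_batch_size, 1, prefill_chunk_size, rbln_max_seq_len], "int64"),
    ("cache_position", [], "int32"),  # scalar in static batching mode
] + [
    # one key buffer and one value buffer per layer, shared by prefill and decode
    (
        f"past_key_values_{i}",
        [rbln_batch_size, num_kv_heads, rbln_max_seq_len, hidden_size // num_attention_heads],
        "float32",
    )
    for i in range(num_hidden_layers * 2)
]

print(len(prefill_input_info))  # 3 graph inputs + 2 * num_hidden_layers cache buffers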
@@ -361,161 +59,3 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
              return redirect(val)

          return val
-
-     def _reorder_cache(self, past_key_values, beam_idx):
-         raise NotImplementedError
-
-     # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
-     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-         batch_size, cur_len = input_ids.shape
-         past_cached_length = past_key_values
-
-         # In greedy decoding
-         if past_cached_length == 0:
-             # padding with prefill_chunk_size
-             # TODO left padding + left padding has issue on stoppingcriteria(max_len)
-             if cur_len % self.prefill_chunk_size != 0:
-                 pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-                 input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-
-             # padding_side
-             if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
-                 self.right_padding = False
-
-             if self.right_padding:
-                 self.rightpad_max_len = cur_len
-                 prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-                 self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len # dummy_decoder generation length
-                 query_length = prompt_min_len.item()
-             else:
-                 query_length = cur_len - past_cached_length
-                 self.prompt_length = query_length
-                 self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
-
-             attention_mask = self.prefill_attention_mask.clone()
-             cache_position = torch.tensor(0, dtype=torch.int32)
-
-         else:
-             if self.right_padding:
-                 attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-                 attention_mask[:, :, :, : past_cached_length + 1] = 1
-                 input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
-             else:
-                 attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
-                 attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-                 input_ids = input_ids[:, -1:]
-
-             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-             query_length = 1
-
-         model_inputs = {
-             "input_ids": input_ids,
-             "past_key_values": past_key_values,
-             "attention_mask": attention_mask,
-             "cache_position": cache_position,
-             "query_length": query_length,
-         }
-
-         return model_inputs
-
-     def forward(self, *args, **kwargs):
-         if self.use_continuous_batch:
-             return self.forward_cb(*args, **kwargs)
-         else:
-             return self.forward_static(*args, **kwargs)
-
-     def forward_static(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         past_key_values: int = None,
-         cache_position: Optional[torch.Tensor] = None,
-         query_length: Optional[torch.Tensor] = None,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor]:
-         if past_key_values is not None:
-             past_key_values += query_length
-
-         # prefill_decoder
-         if cache_position == 0:
-             for step in range(0, query_length, self.prefill_chunk_size):
-                 sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                 attention_mask[:, :, :, :step] = 1
-                 attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                 if not self.right_padding:
-                     attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
-
-                 outputs = self.prefill_decoder(
-                     input_ids=sliced_input_ids.contiguous(),
-                     attention_mask=attention_mask.contiguous(),
-                     cache_position=cache_position + step,
-                 )
-             outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-
-         # decoder
-         else:
-             outputs = self.decoder(
-                 input_ids.contiguous(),
-                 attention_mask.contiguous(),
-                 cache_position=cache_position,
-             )
-
-         return CausalLMOutputWithPast(
-             logits=outputs,
-             past_key_values=past_key_values,
-         )
-
-     def forward_cb(
-         self,
-         input_ids: torch.LongTensor = None,
-         cache_position: Optional[torch.Tensor] = None, # torch.tensor(,dtype=int32) (1,64) // (4,1)
-         batch_idx: int = None,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor]:
-         # prefill_decoder
-         if cache_position.shape[1] > 1:
-             query_length = input_ids.shape[1]
-             attention_mask = self.prefill_attention_mask.clone()
-             for step in range(0, query_length, self.prefill_chunk_size):
-                 if step + self.prefill_chunk_size > query_length:
-                     input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
-                     cache_position = torch.cat(
-                         [
-                             cache_position,
-                             torch.arange(
-                                 query_length,
-                                 step + self.prefill_chunk_size,
-                                 dtype=torch.int32,
-                             ).unsqueeze(0),
-                         ],
-                         dim=-1,
-                     )
-
-                 sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                 sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-                 attention_mask[:, :, :, :step] = 1
-                 attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                 outputs, _ = self.prefill_decoder(
-                     sliced_input_ids.contiguous(),
-                     attention_mask.contiguous(),
-                     sliced_cache_positions.contiguous(),
-                     torch.tensor(batch_idx, dtype=torch.int16),
-                 )
-             outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-         # decoder
-         else:
-             attention_mask = self.decoder_attention_mask.clone()
-             for b_idx in range(self.batch_size):
-                 attention_mask[b_idx, :, :, : cache_position[b_idx].item() + 1] = 1
-
-             outputs = self.decoder(
-                 input_ids.contiguous(),
-                 attention_mask.contiguous(),
-                 cache_position.contiguous(),
-                 torch.tensor(0, dtype=torch.int16),
-             )[0]
-
-         return CausalLMOutputWithPast(
-             logits=outputs,
-         )
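The removed forward paths above also document the chunked-prefill scheme: prompts are processed in prefill_chunk_size slices against a fixed max_seq_len attention window, with a lower-triangular mask inside the current chunk and full visibility of earlier chunks. A standalone sketch of that mask update (illustrative sizes; in 0.1.8 this logic presumably lives in the shared RBLNDecoderOnlyModelForCausalLM base rather than in each model class):

# Standalone illustration of the prefill attention-mask update performed by the
# removed forward_static loop. Sizes are illustrative; the removed code used
# prefill_chunk_size = 128.
import torch

prefill_chunk_size, max_seq_len = 4, 16
causal_mask = 1 - torch.triu(
    torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
)
attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len, dtype=torch.int64)

query_length = 8  # prompt length, already padded to a multiple of the chunk size
for step in range(0, query_length, prefill_chunk_size):
    attention_mask[:, :, :, :step] = 1  # earlier chunks are fully visible
    attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask  # causal inside the chunk
    # the matching input_ids slice would be sent to the prefill runtime here
    print(step, attention_mask[0, 0, -1, : step + prefill_chunk_size].tolist())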
optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py
@@ -10,7 +10,8 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """ Tokenization class for model Midm_bitext_tonkenizer."""
+ """Tokenization class for model Midm_bitext_tonkenizer."""
+
  import os
  import re
  import warnings
optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py
@@ -817,7 +817,6 @@ class MidmModel(MidmPreTrainedModel):
          all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
          all_hidden_states = () if output_hidden_states else None
          for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
              # Model parallel
              if self.model_parallel:
                  torch.cuda.set_device(hidden_states.device)
@@ -833,7 +832,6 @@ class MidmModel(MidmPreTrainedModel):
                  all_hidden_states = all_hidden_states + (hidden_states,)

              if self.gradient_checkpointing and self.training:
-
                  if use_cache:
                      logger.warning(
                          "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -1174,7 +1172,6 @@ class MidmDoubleHeadsModel(MidmPreTrainedModel):
          return_dict=None,
          **kwargs,
      ):
-
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          transformer_outputs = self.transformer(
@@ -1445,7 +1442,6 @@ def get_submodule(module, target: str): # -> "Module":
      mod: torch.nn.Module = module

      for item in atoms:
-
          if not hasattr(mod, item):
              raise AttributeError(mod._get_name() + " has no " "attribute `" + item + "`")