optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries; it is provided for informational purposes only.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
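The bulk of this release consolidates the per-model causal-LM plumbing (llama, gpt2, midm, plus the new gemma and xlm_roberta additions) onto the new shared decoderonly base, `RBLNDecoderOnlyModelForCausalLM`, which is why the individual modeling files shrink so sharply in the diffs below and why `llama_architecture_cb.py` disappears. For orientation, a minimal usage sketch of the consolidated entry point — assuming the optimum-style `from_pretrained(..., export=True)` flow and the `rbln_*` options visible in the removed `_get_rbln_config` below; the checkpoint id is illustrative:

from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM

# Compile the HF checkpoint for the RBLN NPU once; reuse later via save_pretrained()/from_pretrained().
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",   # illustrative checkpoint
    export=True,                       # trace and compile instead of loading a prebuilt artifact
    rbln_max_seq_len=4096,             # static sequence length baked into the compiled graphs
    rbln_batch_size=1,
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
tokens = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))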
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -21,54 +21,20 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect  # noqa: I001
+import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable
 
-import torch  # noqa: F401
-import rebel  # noqa: F401
+from transformers import LlamaForCausalLM, PreTrainedModel
 
-from transformers import AutoModelForCausalLM, LlamaForCausalLM, PreTrainedModel, PretrainedConfig, AutoConfig
-from transformers.modeling_outputs import CausalLMOutputWithPast
-
-from ...generation.utils import RBLNGenerationMixin
-from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-
-
-# FIXME:: Merge Two architecture Codes
-from .llama_architecture import (
-    LlamaWrapper,
-    wrap_llama,
-    unwrap_llama,
-)
-
-from .llama_architecture_cb import (
-    LlamaDynamicBatchWrapper as LlamaWrapper_cb,
-    wrap_llama as wrap_llama_cb,
-)
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .llama_architecture import LlamaWrapper
 
 
 logger = logging.getLogger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
-
-
-SUPPORTED_BATCHING_MODES = ["static", "vllm"]
-
-
-class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
 
-class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
+class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -79,251 +45,9 @@ class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    main_input_name = "input_ids"
-    auto_model_class = AutoModelForCausalLM
-
-    def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.use_continuous_batch = self.rbln_config.meta["rbln_batching"] == "vllm"
-
-        prefill_batch_size = self.batch_size if not self.use_continuous_batch else 1
-        self.prefill_attention_mask = torch.zeros(
-            prefill_batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(prefill_batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-        self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-
-        self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
-        self.past_cached_length = 0
-        self.right_padding = True
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        """
-        Update user-given kwargs to get proper pytorch model.
-
-        For example, `torchscript`=True should be set because torch.jit
-        does not support `transformers` output instances as module output;
-        """
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-                "torch_dtype": torch.float32,
-                "_attn_implementation": "eager",
-            }
-        )
-        return kwargs
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> PreTrainedModel:
-        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
-            config = AutoConfig.from_pretrained(model_id)
-            if hf_position_embedding := getattr(config, "max_position_embeddings", None):
-                if hf_position_embedding < rbln_max_seq_len:
-                    logger.warning(
-                        f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
-                        "This may lead to incorrect inferences of the model."
-                    )
-            kwargs.update({"max_position_embeddings": rbln_max_seq_len})
-
-        # FIXME :: This should be moved when wrapping removed.
-        use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
-        wrap_llama_cb() if use_continuous_batch else wrap_llama()
-
-        model = super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_config_kwargs=rbln_config_kwargs,
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
-            **kwargs,
-        )
-
-        unwrap_llama()
-
-        return model
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        use_continuous_batch = rbln_config.meta["rbln_batching"] == "vllm"
-
-        wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
-
-        wrapped_model = wrapper_cls(model).eval()
-
-        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
-
-        if use_continuous_batch:
-            batch_index_index = 3
-            dec_example_inputs[batch_index_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
-
-        wrap_llama_cb() if use_continuous_batch else wrap_llama()
-
-        prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)
-
-        unwrap_llama()
-
-        prefill_ir = rebel.torchscript_to_ir(
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
-
-        # Caching prefill_decoder/decoder I/O
-        cache_index_offset = 4 if use_continuous_batch else 3
-        connections = [
-            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-            for i in range(model.config.num_hidden_layers * 2)
-        ]
-
-        compiled_model = rebel.compile(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-            use_weight_sharing=True,
-        )
-        return compiled_model
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_batching: Optional[str] = None,
-    ) -> RBLNConfig:
-        meta = {}
-
-        prefill_chunk_size = 128
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_batching = "static" if rbln_batching is None else rbln_batching
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
-        meta["rbln_batching"] = rbln_batching
-        use_continuous_batching = meta["rbln_batching"] == "vllm"
-
-        if rbln_batching not in SUPPORTED_BATCHING_MODES:
-            raise ValueError(
-                f'rbln_batching="{rbln_batching}" is not a supported batch mode, '
-                f"Possible: {SUPPORTED_BATCHING_MODES}"
-            )
-
-        def get_input_info(
-            batch_size,  # should be 1 if continous batch prefill
-            query_length,
-            continuous_batch=False,  # determines the shape of `cache position`
-        ):
-            input_info = [
-                ("input_ids", [batch_size, query_length], "int64"),
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                (
-                    "cache_position",
-                    [batch_size, query_length] if continuous_batch else [],
-                    "int32",
-                ),
-            ]
-
-            if continuous_batch:
-                input_info.append(("batch_position", [], "int16"))
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_batch_size,
-                            model_config.num_key_value_heads,
-                            rbln_max_seq_len,
-                            model_config.hidden_size // model_config.num_attention_heads,
-                        ],
-                        "float32",
-                    )
-                    for i in range(model_config.num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
-            batch_size=1 if use_continuous_batching else rbln_batch_size,
-            query_length=prefill_chunk_size,
-            continuous_batch=use_continuous_batching,
-        )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
-            query_length=1,
-            continuous_batch=use_continuous_batching,
-        )
-
-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-        dec_rbln_runtime_config.batch_size = rbln_batch_size
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
-        )
-
-        return rbln_config
-
     @classmethod
-    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
-    ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def get_decoder(self):
-        return self.decoder
-
-    def can_generate(self):
-        return True
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return LlamaWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
@@ -335,161 +59,3 @@ class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
             return redirect(val)
 
         return val
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        raise NotImplementedError
-
-    # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        # In greedy decoding
-        if past_cached_length == 0:
-            # padding with prefill_chunk_size
-            # TODO left padding + left padding has issue on stoppingcriteria(max_len)
-            if cur_len % self.prefill_chunk_size != 0:
-                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-
-            # padding_side
-            if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
-                self.right_padding = False
-
-            if self.right_padding:
-                self.rightpad_max_len = cur_len
-                prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-                self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len  # dummy_decoder generation length
-                query_length = prompt_min_len.item()
-            else:
-                query_length = cur_len - past_cached_length
-                self.prompt_length = query_length
-                self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
-
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(0, dtype=torch.int32)
-
-        else:
-            if self.right_padding:
-                attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-                attention_mask[:, :, :, : past_cached_length + 1] = 1
-                input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
-            else:
-                attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
-                attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-                input_ids = input_ids[:, -1:]
-
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(self, *args, **kwargs):
-        if self.use_continuous_batch:
-            return self.forward_cb(*args, **kwargs)
-        else:
-            return self.forward_static(*args, **kwargs)
-
-    def forward_static(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: int = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        if past_key_values is not None:
-            past_key_values += query_length
-
-        # prefill_decoder
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                if not self.right_padding:
-                    attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
-
-                outputs = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position + step,
-                )
-            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-
-        # decoder
-        else:
-            outputs = self.decoder(
-                input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position=cache_position,
-            )
-
-        return CausalLMOutputWithPast(
-            logits=outputs,
-            past_key_values=past_key_values,
-        )
-
-    def forward_cb(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Optional[torch.Tensor] = None,  # torch.tensor(,dtype=int32) (1,64) // (4,1)
-        batch_idx: int = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefill_decoder
-        if cache_position.shape[1] > 1:
-            query_length = input_ids.shape[1]
-            attention_mask = self.prefill_attention_mask.clone()
-            for step in range(0, query_length, self.prefill_chunk_size):
-                if step + self.prefill_chunk_size > query_length:
-                    input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
-                    cache_position = torch.cat(
-                        [
-                            cache_position,
-                            torch.arange(
-                                query_length,
-                                step + self.prefill_chunk_size,
-                                dtype=torch.int32,
-                            ).unsqueeze(0),
-                        ],
-                        dim=-1,
-                    )
-
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                outputs, _ = self.prefill_decoder(
-                    sliced_input_ids.contiguous(),
-                    attention_mask.contiguous(),
-                    sliced_cache_positions.contiguous(),
-                    torch.tensor(batch_idx, dtype=torch.int16),
-                )
-            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-        # decoder
-        else:
-            attention_mask = self.decoder_attention_mask.clone()
-            for b_idx in range(self.batch_size):
-                attention_mask[b_idx, :, :, : cache_position[b_idx].item() + 1] = 1
-
-            outputs = self.decoder(
-                input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position.contiguous(),
-                torch.tensor(0, dtype=torch.int16),
-            )[0]
-
-        return CausalLMOutputWithPast(
-            logits=outputs,
-        )
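The prefill/decode scheduling deleted here — chunked prefill with a sliding causal mask, plus the one-token decode mask — now lives in the shared `RBLNDecoderOnlyModelForCausalLM`; the Llama class only has to supply `wrapping_torch_model`. For reference, a standalone PyTorch sketch of the chunked-prefill mask bookkeeping that the removed `forward_static` performed. The helper below is illustrative, not part of the package; tensor names follow the removed code:

import torch


def chunked_prefill_masks(query_length, prefill_chunk_size, max_seq_len):
    # Lower-triangular pattern used inside the chunk currently being prefilled.
    causal_mask = 1 - torch.triu(
        torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
    )
    attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len, dtype=torch.int64)
    for step in range(0, query_length, prefill_chunk_size):
        # Everything before the current chunk is fully visible; the chunk itself is causal.
        attention_mask[:, :, :, :step] = 1
        attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask
        yield step, attention_mask.clone()


# Each 128-token chunk of a 256-token prompt gets its own mask against the 512-slot cache.
for step, mask in chunked_prefill_masks(query_length=256, prefill_chunk_size=128, max_seq_len=512):
    print(step, int(mask[0, 0, 0].sum()), int(mask[0, 0, -1].sum()))

The decode-phase graph then runs one token at a time against the full `max_seq_len` mask, which is why both graphs are compiled together with the KV-cache I/O `connections` seen in the removed `get_compiled_model` above.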
optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py
@@ -10,7 +10,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Tokenization class for model Midm_bitext_tonkenizer."""
+"""Tokenization class for model Midm_bitext_tonkenizer."""
+
 import os
 import re
 import warnings
optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py
@@ -817,7 +817,6 @@ class MidmModel(MidmPreTrainedModel):
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -833,7 +832,6 @@
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -1174,7 +1172,6 @@ class MidmDoubleHeadsModel(MidmPreTrainedModel):
         return_dict=None,
         **kwargs,
     ):
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
@@ -1445,7 +1442,6 @@ def get_submodule(module, target: str):  # -> "Module":
     mod: torch.nn.Module = module
 
     for item in atoms:
-
         if not hasattr(mod, item):
             raise AttributeError(mod._get_name() + " has no " "attribute `" + item + "`")