optimum-rbln 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +7 -0
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
  5. optimum/rbln/diffusers/models/controlnet.py +93 -23
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
  7. optimum/rbln/diffusers/pipelines/__init__.py +7 -2
  8. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
  10. optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
  15. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
  18. optimum/rbln/modeling_base.py +39 -6
  19. optimum/rbln/modeling_seq2seq.py +19 -4
  20. optimum/rbln/transformers/__init__.py +2 -0
  21. optimum/rbln/transformers/generation/__init__.py +1 -0
  22. optimum/rbln/transformers/generation/streamers.py +17 -0
  23. optimum/rbln/transformers/generation/utils.py +399 -0
  24. optimum/rbln/transformers/models/__init__.py +1 -0
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +49 -17
  27. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +759 -0
  28. optimum/rbln/transformers/models/llama/modeling_llama.py +187 -75
  29. optimum/rbln/transformers/models/midm/__init__.py +32 -0
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
  32. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
  33. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
  34. optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
  35. optimum/rbln/transformers/models/midm/modeling_midm.py +426 -0
  36. optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
  37. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/METADATA +5 -4
  38. optimum_rbln-0.1.4.dist-info/RECORD +63 -0
  39. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/WHEEL +1 -1
  40. optimum_rbln-0.1.0.dist-info/RECORD +0 -51
  41. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -34,16 +34,25 @@ from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
+from ...generation.utils import RBLNGenerationMixin
 from ....modeling_base import RBLNBaseModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.save_utils import maybe_save_preprocessors
+
+
+# FIXME:: Merge Two architecture Codes
 from .llama_architecture import (
     LlamaWrapper,
     wrap_llama,
     unwrap_llama,
 )
 
+from .llama_architecture_cb import (
+    LlamaDynamicBatchWrapper as LlamaWrapper_cb,
+    wrap_llama as wrap_llama_cb,
+)
+
 
 logger = logging.getLogger(__name__)
 
@@ -56,26 +65,14 @@ if TYPE_CHECKING:
     )
 
 
+SUPPORTED_BATCHING_MODES = ["static", "vllm"]
+
+
 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    # RBLN_Runtimemodule
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.LongTensor = None,
-        cache_position: torch.Tensor = None,
-        **kwargs: Dict[str, Any],
-    ):
-        logits = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-        return logits
-
 
-class RBLNLlamaForCausalLM(RBLNBaseModel):
+class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -91,21 +88,24 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
     auto_model_class = AutoModelForCausalLM
 
     def __post_init__(self, **kwargs):
-
         self.batch_size = self.rbln_config.meta["rbln_batch_size"]
         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+        self.use_continuous_batch = self.rbln_config.meta["rbln_batching"] == "vllm"
 
+        prefill_batch_size = self.batch_size if not self.use_continuous_batch else 1
         self.prefill_attention_mask = torch.zeros(
-            self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
+            prefill_batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
         )
         self.causal_mask = 1 - torch.triu(
-            torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+            torch.ones(prefill_batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
+        self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
 
         self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
         self.past_cached_length = 0
+        self.right_padding = True
 
     @classmethod
     @torch.no_grad()
@@ -120,14 +120,23 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNLlamaForCausalLM":
         task = kwargs.pop("task", None)
        if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         def update_configs(kwargs):
             hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
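
The new model_save_dir parameter keeps the compiled artifacts around instead of discarding them with a TemporaryDirectory. A hedged usage sketch, assuming the kwarg is threaded through from the usual from_pretrained(export=True) entry point (that plumbing is not shown in this hunk); the checkpoint id and directory name are illustrative:

    from optimum.rbln import RBLNLlamaForCausalLM

    # Illustrative: compile once and keep the .rbln artifacts in a named directory.
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",         # illustrative checkpoint
        export=True,                        # compile from the Hugging Face weights
        model_save_dir="./rbln_llama_7b",   # assumed to be forwarded to this export path
    )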
@@ -154,7 +163,10 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
 
-        origin_mehtods = wrap_llama()
+        # FIXME :: This should be moved when wrapping removed.
+        use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
+        origin_mehtods = wrap_llama_cb() if use_continuous_batch else wrap_llama()
+
         model: LlamaForCausalLM = TasksManager.get_model_from_task(
             task=task,
             model_name_or_path=model_id,
@@ -181,14 +193,18 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
             preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
         )
 
-        def compile_llama():
-            wrapped_model = LlamaWrapper(model).eval()
+        def compile_llama(use_continuous_batch, wrapper_cls):
+            wrapped_model = wrapper_cls(model).eval()
 
             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+
+            if use_continuous_batch:
+                batch_index_index = 3
+                dec_example_inputs[batch_index_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
 
             prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs)
             dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs)
@@ -203,8 +219,9 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
             )
 
             # Caching prefill_decoder/decoder I/O
+            cache_index_offset = 4 if use_continuous_batch else 3
             connections = [
-                (prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i])
+                (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
                 for i in range(model.config.num_hidden_layers * 2)
             ]
 
@@ -219,7 +236,8 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
             )
             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
 
-        compile_llama()
+        wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
+        compile_llama(use_continuous_batch=use_continuous_batch, wrapper_cls=wrapper_cls)
         unwrap_llama(origin_mehtods)
 
         rbln_config.save(save_dir_path)
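
For orientation on the connections change two hunks above: output 0 of the prefill graph is the logits and outputs 1..2*num_hidden_layers are the updated key/value caches, and with the extra batch_position input in the vllm path the cache inputs start at index 4 instead of 3. A small sketch of how the pairs line up, with an illustrative layer count:

    num_hidden_layers = 32  # illustrative
    use_continuous_batch = True
    cache_index_offset = 4 if use_continuous_batch else 3  # input_ids, attention_mask, cache_position(, batch_position)

    # Pair each KV-cache output back to its matching cache input so prefill and decode share state.
    connections = [(1 + i, cache_index_offset + i) for i in range(num_hidden_layers * 2)]
    print(connections[:2])  # [(1, 4), (2, 5)]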
@@ -239,27 +257,46 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         model_config: "PretrainedConfig",
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_batching: Optional[str] = None,
     ) -> RBLNConfig:
         meta = {}
 
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
             rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        rbln_batching = "static" if rbln_batching is None else rbln_batching
 
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_batch_size"] = rbln_batch_size
         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+        meta["rbln_batching"] = rbln_batching
+        use_continuous_batching = meta["rbln_batching"] == "vllm"
 
-        def get_input_info(query_length):
+        if rbln_batching not in SUPPORTED_BATCHING_MODES:
+            raise ValueError(
+                f'rbln_batching="{rbln_batching}" is not a supported batch mode, '
+                f"Possible: {SUPPORTED_BATCHING_MODES}"
+            )
+
+        def get_input_info(
+            batch_size,  # should be 1 if continous batch prefill
+            query_length,
+            continuous_batch=False,  # determines the shape of `cache position`
+        ):
             input_info = [
-                ("input_ids", [rbln_batch_size, query_length], "int64"),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                ("input_ids", [batch_size, query_length], "int64"),
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                 (
                     "cache_position",
-                    [],
+                    [batch_size, query_length] if continuous_batch else [],
                     "int32",
                 ),
             ]
+
+            if continuous_batch:
+                input_info.append(("batch_position", [], "int16"))
+
            input_info.extend(
                [
                    (
@@ -275,10 +312,19 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
                    for i in range(model_config.num_hidden_layers * 2)
                ]
            )
+
            return input_info
 
-        prefill_input_info = get_input_info(query_length=prefill_chunk_size)
-        dec_input_info = get_input_info(query_length=1)
+        prefill_input_info = get_input_info(
+            batch_size=1 if use_continuous_batching else rbln_batch_size,
+            query_length=prefill_chunk_size,
+            continuous_batch=use_continuous_batching,
+        )
+        dec_input_info = get_input_info(
+            batch_size=rbln_batch_size,
+            query_length=1,
+            continuous_batch=use_continuous_batching,
+        )
 
         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
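
To make the shapes concrete, this is roughly what the two input_info lists come out as on the vllm (continuous batching) path, assuming an illustrative rbln_batch_size of 4 and rbln_max_seq_len of 4096, and omitting the per-layer KV-cache entries that follow:

    # Prefill graph: single sequence, one cache position per token, plus a batch_position scalar.
    prefill_input_info = [
        ("input_ids", [1, 128], "int64"),
        ("attention_mask", [1, 1, 128, 4096], "int64"),
        ("cache_position", [1, 128], "int32"),
        ("batch_position", [], "int16"),
    ]

    # Decode graph: full batch, a single query token per sequence.
    dec_input_info = [
        ("input_ids", [4, 1], "int64"),
        ("attention_mask", [4, 1, 1, 4096], "int64"),
        ("cache_position", [4, 1], "int32"),
        ("batch_position", [], "int16"),
    ]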
@@ -321,23 +367,46 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
     # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, hf_input_length = input_ids.shape
+        batch_size, cur_len = input_ids.shape
         past_cached_length = past_key_values
-        query_length = hf_input_length - past_cached_length
 
         # In greedy decoding
-        if past_key_values == 0:
-            self.prompt_length = query_length
-            self.prompt_ids = input_ids
-            self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
-
-            attention_mask = torch.zeros(batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64)
+        if past_cached_length == 0:
+            # padding with prefill_chunk_size
+            # TODO left padding + left padding has issue on stoppingcriteria(max_len)
+            if cur_len % self.prefill_chunk_size != 0:
+                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
+                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
+
+            # padding_side
+            if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
+                self.right_padding = False
+
+            if self.right_padding:
+                self.rightpad_max_len = cur_len
+                prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+                self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len  # dummy_decoder generation length
+                query_length = prompt_min_len.item()
+            else:
+                query_length = cur_len - past_cached_length
+                self.prompt_length = query_length
+                self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
+
+            attention_mask = self.prefill_attention_mask.clone()
             cache_position = torch.tensor(0, dtype=torch.int32)
+
         else:
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - hf_input_length))
-            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
+            if self.right_padding:
+                attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
+                attention_mask[:, :, :, : past_cached_length + 1] = 1
+                input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
+            else:
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
+                attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
+                input_ids = input_ids[:, -1:]
+
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            input_ids = input_ids[:, -1:]
+            query_length = 1
 
         model_inputs = {
             "input_ids": input_ids,
@@ -349,7 +418,13 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
         return model_inputs
 
-    def forward(
+    def forward(self, *args, **kwargs):
+        if self.use_continuous_batch:
+            return self.forward_cb(*args, **kwargs)
+        else:
+            return self.forward_static(*args, **kwargs)
+
+    def forward_static(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
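
The padding-side check in prepare_inputs_for_generation above treats a batch as left-padded only when every sequence still has a real token in the last attention-mask column. A minimal standalone sketch of that test (toy values, not from the library):

    import torch

    def is_left_padded(attention_mask: torch.Tensor) -> bool:
        # Mirrors the check in prepare_inputs_for_generation: with batch_size > 1,
        # an all-ones final column means the pad tokens sit on the left.
        return attention_mask.shape[0] > 1 and bool(torch.all(attention_mask[..., -1] == 1))

    right = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # pads at the end
    left = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])   # pads at the front
    print(is_left_padded(right), is_left_padded(left))  # False True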
@@ -363,38 +438,20 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
         # prefill_decoder
         if cache_position == 0:
-            while query_length > self.prefill_chunk_size:
-                # prepare input_ids & attention_mask
-                sliced_input_ids = input_ids[:, cache_position : cache_position + self.prefill_chunk_size].contiguous()
-                attention_mask[:, :, :, :cache_position] = 1
-                attention_mask[:, :, :, cache_position : cache_position + self.prefill_chunk_size] = self.causal_mask
-                attention_mask[:, :, :, : self.prompt_length] *= self.prompt_attn_mask[:, :, :, :]
-
-                _ = self.prefill_decoder(
-                    sliced_input_ids,
-                    attention_mask,
-                    cache_position,
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if not self.right_padding:
+                    attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
+
+                outputs = self.prefill_decoder(
+                    input_ids=sliced_input_ids.contiguous(),
+                    attention_mask=attention_mask.contiguous(),
+                    cache_position=cache_position + step,
                 )
-                # update query_length & cache_position
-                query_length -= self.prefill_chunk_size
-                cache_position += self.prefill_chunk_size
-
-            # prepare input_ids & attention_mask
-            last_input_ids = input_ids[:, cache_position : cache_position + query_length]
-            last_input_ids = torch.nn.functional.pad(last_input_ids, (0, self.prefill_chunk_size - query_length))
+            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
 
-            attention_mask[:, :, :, :cache_position] = 1
-            mask_slice = self.causal_mask[:, :, :query_length, :query_length]
-            attention_mask[:, :, :query_length, cache_position : cache_position + query_length] = mask_slice
-            attention_mask[:, :, :, : self.prompt_length] *= self.prompt_attn_mask[:, :, :, :]
-
-            outputs = self.prefill_decoder(
-                last_input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position,
-            )
-
-            outputs = outputs[:, query_length - 1].unsqueeze(1)
         # decoder
         else:
             outputs = self.decoder(
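
The query_length % self.prefill_chunk_size - 1 indexing picks the logits of the last real prompt token out of the final, zero-padded chunk. A short worked example with the 128-token chunk size used here:

    prefill_chunk_size = 128

    for query_length in (300, 256, 40):
        last_token_offset = query_length % prefill_chunk_size - 1
        # 300 tokens -> chunks start at 0, 128, 256; the last real token sits at offset 43 of the final chunk.
        # 256 tokens -> the expression gives -1, i.e. the final position of a fully used chunk.
        print(query_length, last_token_offset)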
@@ -407,3 +464,58 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
             logits=outputs,
             past_key_values=past_key_values,
         )
+
+    def forward_cb(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Optional[torch.Tensor] = None,  # torch.tensor(,dtype=int32) (1,64) // (4,1)
+        batch_idx: int = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # prefill_decoder
+        if cache_position.shape[1] > 1:
+            query_length = input_ids.shape[1]
+            attention_mask = self.prefill_attention_mask.clone()
+            for step in range(0, query_length, self.prefill_chunk_size):
+                if step + self.prefill_chunk_size > query_length:
+                    input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
+                    cache_position = torch.cat(
+                        [
+                            cache_position,
+                            torch.arange(
+                                query_length,
+                                step + self.prefill_chunk_size,
+                                dtype=torch.int32,
+                            ).unsqueeze(0),
+                        ],
+                        dim=-1,
+                    )
+
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+                outputs, _ = self.prefill_decoder(
+                    sliced_input_ids.contiguous(),
+                    attention_mask.contiguous(),
+                    sliced_cache_positions.contiguous(),
+                    torch.tensor(batch_idx, dtype=torch.int16),
+                )
+            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
+        # decoder
+        else:
+            attention_mask = self.decoder_attention_mask.clone()
+            for b_idx in range(self.batch_size):
+                attention_mask[b_idx, :, :, : cache_position[b_idx].item() + 1] = 1
+
+            outputs = self.decoder(
+                input_ids.contiguous(),
+                attention_mask.contiguous(),
+                cache_position.contiguous(),
+                torch.tensor(0, dtype=torch.int16),
+            )[0]
+
+        return CausalLMOutputWithPast(
+            logits=outputs,
+        )
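
The shape comment on cache_position ((1,64) // (4,1)) is the convention forward_cb keys off: more than one column means prefill of a single request, while one column per batch slot means a decode step. A sketch of the two call shapes, with illustrative values:

    import torch

    # Prefill: one new request with 64 prompt tokens, to be written into batch slot 2.
    prefill_cache_position = torch.arange(64, dtype=torch.int32).unsqueeze(0)  # shape (1, 64)
    prefill_batch_idx = 2  # becomes the int16 batch_position tensor passed to the prefill runtime

    # Decode: four running sequences, each advancing by one token from its own length.
    decode_cache_position = torch.tensor([[17], [5], [63], [30]], dtype=torch.int32)  # shape (4, 1)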
optimum/rbln/transformers/models/midm/__init__.py (new file)
@@ -0,0 +1,32 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import os
+from os import environ
+
+
+this_path = os.path.abspath(__file__)
+local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
+environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
+
+from .modeling_midm import RBLNMidmLMHeadModel
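
The environment variable points the loader at the remote-code files bundled under hf_hub_cached/. A hedged usage sketch, assuming RBLNMidmLMHeadModel follows the same from_pretrained/export pattern as the other RBLN causal-LM classes; the checkpoint id and rbln_* values are illustrative:

    from optimum.rbln import RBLNMidmLMHeadModel

    # Hypothetical usage; values below are illustrative.
    model = RBLNMidmLMHeadModel.from_pretrained(
        "KT-AI/midm-bitext-S-7B-inst-v1",
        export=True,
        trust_remote_code=True,   # Mi:dm ships custom modeling code
        rbln_max_seq_len=4096,
        rbln_batch_size=1,
    )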
optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py (new file)
@@ -0,0 +1,22 @@
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+
+class MidmBitextConfig(GPT2Config):
+    model_type = "midm-bitext-S"
+
+    def __init__(
+        self,
+        use_absolute_position_embedding: bool = True,
+        use_rotary_position_embedding: bool = False,
+        rotary_percentage: float = 1.0,
+        normalization_type: str = "layernorm",
+        scale_qk_by_inverse_layer_idx: bool = False,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.use_absolute_position_embedding = use_absolute_position_embedding
+        self.use_rotary_position_embedding = use_rotary_position_embedding
+        self.rotary_percentage = rotary_percentage
+        self.normalization_type = normalization_type
+        self.scale_qk_by_inverse_layer_idx = scale_qk_by_inverse_layer_idx
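
Since MidmBitextConfig only adds a handful of flags on top of GPT2Config, it can be instantiated directly; a small illustrative example (the field values are made up):

    config = MidmBitextConfig(
        use_absolute_position_embedding=False,
        use_rotary_position_embedding=True,
        rotary_percentage=0.5,
        scale_qk_by_inverse_layer_idx=True,
        n_layer=32, n_head=32, n_embd=4096,  # inherited GPT2Config fields
    )
    print(config.model_type)  # "midm-bitext-S"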