optimum-rbln 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. optimum/rbln/__init__.py +9 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  4. optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
  5. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
  8. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  9. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  10. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  12. optimum/rbln/modeling_base.py +175 -103
  13. optimum/rbln/modeling_seq2seq.py +58 -132
  14. optimum/rbln/transformers/__init__.py +4 -0
  15. optimum/rbln/transformers/models/__init__.py +2 -0
  16. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  17. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  18. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
  20. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
  21. optimum/rbln/transformers/models/llama/llama_architecture.py +62 -33
  22. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +764 -0
  23. optimum/rbln/transformers/models/llama/modeling_llama.py +208 -140
  24. optimum/rbln/transformers/models/midm/__init__.py +32 -0
  25. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
  26. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
  27. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
  29. optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
  30. optimum/rbln/transformers/models/midm/modeling_midm.py +390 -0
  31. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  32. optimum/rbln/utils/__init__.py +1 -1
  33. optimum/rbln/utils/import_utils.py +46 -0
  34. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -50
  35. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +37 -27
  36. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
  37. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -23,28 +23,32 @@

  import inspect # noqa: I001
  import logging
- from pathlib import Path
- from tempfile import TemporaryDirectory
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import torch # noqa: F401
  import rebel # noqa: F401

- from optimum.exporters import TasksManager
- from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
+ from transformers import AutoModelForCausalLM, LlamaForCausalLM, PreTrainedModel, PretrainedConfig, AutoConfig
  from transformers.modeling_outputs import CausalLMOutputWithPast

  from ...generation.utils import RBLNGenerationMixin
- from ....modeling_base import RBLNBaseModel
+ from ....modeling_base import RBLNModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ....utils.save_utils import maybe_save_preprocessors
+
+
+ # FIXME:: Merge Two architecture Codes
  from .llama_architecture import (
      LlamaWrapper,
      wrap_llama,
      unwrap_llama,
  )

+ from .llama_architecture_cb import (
+     LlamaDynamicBatchWrapper as LlamaWrapper_cb,
+     wrap_llama as wrap_llama_cb,
+ )
+

  logger = logging.getLogger(__name__)

@@ -57,29 +61,17 @@ if TYPE_CHECKING:
  )


+ SUPPORTED_BATCHING_MODES = ["static", "vllm"]
+
+
  class RBLNRuntimeModel(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]

-     # RBLN_Runtimemodule
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: torch.LongTensor = None,
-         cache_position: torch.Tensor = None,
-         **kwargs: Dict[str, Any],
-     ):
-         logits = super().forward(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             cache_position=cache_position,
-         )
-         return logits
-

- class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
+ class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
      """
      The Llama Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

      A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -87,7 +79,6 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
      - compiling the resulting graph using the RBLN compiler.
      """

-     model_type = "rbln_model"
      main_input_name = "input_ids"
      auto_model_class = AutoModelForCausalLM

@@ -95,25 +86,45 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
          self.batch_size = self.rbln_config.meta["rbln_batch_size"]
          self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
          self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+         self.use_continuous_batch = self.rbln_config.meta["rbln_batching"] == "vllm"

+         prefill_batch_size = self.batch_size if not self.use_continuous_batch else 1
          self.prefill_attention_mask = torch.zeros(
-             self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
+             prefill_batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
          )
          self.causal_mask = 1 - torch.triu(
-             torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+             torch.ones(prefill_batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
          )
+         self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)

-         self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
-         self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
+         self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
+         self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
          self.past_cached_length = 0
          self.right_padding = True

      @classmethod
-     @torch.no_grad()
-     def _export(
+     def update_kwargs(cls, kwargs):
+         """
+         Update user-given kwargs to get proper pytorch model.
+
+         For example, `torchscript`=True should be set because torch.jit
+         does not support `transformers` output instances as module output;
+         """
+         kwargs.update(
+             {
+                 "torchscript": True,
+                 "return_dict": False,
+                 "use_cache": True,
+                 "torch_dtype": torch.float32,
+                 "_attn_implementation": "eager",
+             }
+         )
+         return kwargs
+
+     @classmethod
+     def get_pytorch_model(
          cls,
          model_id: str,
-         config: "PretrainedConfig",
          use_auth_token: Optional[Union[bool, str]] = None,
          revision: Optional[str] = None,
          force_download: bool = False,
@@ -121,126 +132,94 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+         rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+         rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
          **kwargs,
-     ) -> "RBLNLlamaForCausalLM":
-         task = kwargs.pop("task", None)
-         if task is None:
-             task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-         if model_save_dir is None:
-             save_dir = TemporaryDirectory()
-             save_dir_path = Path(save_dir.name)
-         else:
-             save_dir = model_save_dir
-             if isinstance(save_dir, TemporaryDirectory):
-                 save_dir_path = Path(model_save_dir.name)
-             else:
-                 save_dir_path = Path(model_save_dir)
-                 save_dir_path.mkdir(exist_ok=True)
-
-         def update_configs(kwargs):
-             hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
-             max_seq_len = kwargs.get("rbln_max_seq_len", None)
-             if max_seq_len is not None:
-                 if max_seq_len <= hf_max_position_embeddings:
-                     kwargs.update({"max_position_embeddings": max_seq_len})
-                 else:
-                     raise ValueError("`max_seq_len` should be less or equal than max_position_embeddings!")
-
-             kwargs.update(
-                 {
-                     "torchscript": True,
-                     "return_dict": False,
-                     "use_cache": True,
-                     "torch_dtype": torch.float32,
-                     "_attn_implementation": "eager",
-                 }
-             )
-
-             return kwargs
-
-         kwargs = update_configs(kwargs)
+     ) -> PreTrainedModel:
+         if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
+             config = AutoConfig.from_pretrained(model_id)
+             if hf_position_embedding := getattr(config, "max_position_embeddings", None):
+                 if hf_position_embedding < rbln_max_seq_len:
+                     logger.warning(
+                         f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
+                         "This may lead to incorrect inferences of the model."
+                     )
+             kwargs.update({"max_position_embeddings": rbln_max_seq_len})

-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+         # FIXME :: This should be moved when wrapping removed.
+         use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
+         wrap_llama_cb() if use_continuous_batch else wrap_llama()

-         origin_mehtods = wrap_llama()
-         model: LlamaForCausalLM = TasksManager.get_model_from_task(
-             task=task,
-             model_name_or_path=model_id,
-             subfolder=subfolder,
+         model = super().get_pytorch_model(
+             model_id=model_id,
+             use_auth_token=use_auth_token,
              revision=revision,
-             framework="pt",
+             force_download=force_download,
              cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
+             subfolder=subfolder,
              local_files_only=local_files_only,
-             force_download=force_download,
              trust_remote_code=trust_remote_code,
+             rbln_config_kwargs=rbln_config_kwargs,
+             rbln_constructor_kwargs=rbln_constructor_kwargs,
              **kwargs,
          )

-         if config is None:
-             config = model.config
+         unwrap_llama()

-         config.save_pretrained(save_dir_path)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+         return model

-         # Get compilation arguments
-         if rbln_config_kwargs.get("rbln_config", None) is None:
-             rbln_config = cls.get_rbln_config(
-                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-             )
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+         use_continuous_batch = rbln_config.meta["rbln_batching"] == "vllm"

-         def compile_llama():
-             wrapped_model = LlamaWrapper(model).eval()
+         wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper

-             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+         wrapped_model = wrapper_cls(model).eval()

-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]

-             prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs)
-             dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs)
+         prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+         dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)

-             prefill_ir = rebel.torchscript_to_ir(
-                 prefill_scripted_model,
-                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-             )
-             dec_ir = rebel.torchscript_to_ir(
-                 dec_scripted_model,
-                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-             )
+         if use_continuous_batch:
+             batch_index_index = 3
+             dec_example_inputs[batch_index_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.

-             # Caching prefill_decoder/decoder I/O
-             connections = [
-                 (prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i])
-                 for i in range(model.config.num_hidden_layers * 2)
-             ]
+         wrap_llama_cb() if use_continuous_batch else wrap_llama()

-             compiled_model = rebel.compile(
-                 prefill_ir,
-                 dec_ir,
-                 connections=connections,
-                 fusion=prefill_rbln_runtime_config.fusion,
-                 npu=prefill_rbln_runtime_config.npu,
-                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                 use_weight_sharing=True,
-             )
-             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+         prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
+         dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)

-         compile_llama()
-         unwrap_llama(origin_mehtods)
+         unwrap_llama()

-         rbln_config.save(save_dir_path)
+         prefill_ir = rebel.torchscript_to_ir(
+             prefill_scripted_model,
+             input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+         )
+         dec_ir = rebel.torchscript_to_ir(
+             dec_scripted_model,
+             input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+         )

-         return cls._from_pretrained(
-             model_id=save_dir_path,
-             config=config,
-             model_save_dir=save_dir,
-             **rbln_constructor_kwargs,
-             **kwargs,
+         # Caching prefill_decoder/decoder I/O
+         cache_index_offset = 4 if use_continuous_batch else 3
+         connections = [
+             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
+             for i in range(model.config.num_hidden_layers * 2)
+         ]
+
+         compiled_model = rebel.compile(
+             prefill_ir,
+             dec_ir,
+             connections=connections,
+             fusion=prefill_rbln_runtime_config.fusion,
+             npu=prefill_rbln_runtime_config.npu,
+             tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+             use_weight_sharing=True,
          )
+         return compiled_model

      @classmethod
      def _get_rbln_config(
@@ -249,6 +228,7 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
          model_config: "PretrainedConfig",
          rbln_max_seq_len: Optional[int] = None,
          rbln_batch_size: Optional[int] = None,
+         rbln_batching: Optional[str] = None,
      ) -> RBLNConfig:
          meta = {}

@@ -256,21 +236,38 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
          if rbln_max_seq_len is None:
              rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
          rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+         rbln_batching = "static" if rbln_batching is None else rbln_batching

          meta["rbln_max_seq_len"] = rbln_max_seq_len
          meta["rbln_batch_size"] = rbln_batch_size
          meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+         meta["rbln_batching"] = rbln_batching
+         use_continuous_batching = meta["rbln_batching"] == "vllm"
+
+         if rbln_batching not in SUPPORTED_BATCHING_MODES:
+             raise ValueError(
+                 f'rbln_batching="{rbln_batching}" is not a supported batch mode, '
+                 f"Possible: {SUPPORTED_BATCHING_MODES}"
+             )

-         def get_input_info(query_length):
+         def get_input_info(
+             batch_size, # should be 1 if continous batch prefill
+             query_length,
+             continuous_batch=False, # determines the shape of `cache position`
+         ):
              input_info = [
-                 ("input_ids", [rbln_batch_size, query_length], "int64"),
-                 ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                 ("input_ids", [batch_size, query_length], "int64"),
+                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                  (
                      "cache_position",
-                     [],
+                     [batch_size, query_length] if continuous_batch else [],
                      "int32",
                  ),
              ]
+
+             if continuous_batch:
+                 input_info.append(("batch_position", [], "int16"))
+
              input_info.extend(
                  [
                      (
@@ -286,10 +283,19 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
                      for i in range(model_config.num_hidden_layers * 2)
                  ]
              )
+
              return input_info

-         prefill_input_info = get_input_info(query_length=prefill_chunk_size)
-         dec_input_info = get_input_info(query_length=1)
+         prefill_input_info = get_input_info(
+             batch_size=1 if use_continuous_batching else rbln_batch_size,
+             query_length=prefill_chunk_size,
+             continuous_batch=use_continuous_batching,
+         )
+         dec_input_info = get_input_info(
+             batch_size=rbln_batch_size,
+             query_length=1,
+             continuous_batch=use_continuous_batching,
+         )

          prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
          dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
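For orientation (not part of the diff): with the new `get_input_info` signature, "static" and "vllm" batching compile against different input shapes. A minimal sketch of the shapes implied by the code above, assuming a hypothetical config with `rbln_batch_size=4`, `rbln_max_seq_len=4096`, and a prefill chunk of 128 tokens (placeholder values):

    # Illustrative shapes only; numbers are placeholders, not taken from a real config.
    batch_size, max_seq_len, chunk = 4, 4096, 128

    static_prefill = {
        "input_ids": [batch_size, chunk],
        "attention_mask": [batch_size, 1, chunk, max_seq_len],
        "cache_position": [],  # scalar position
    }
    vllm_prefill = {
        "input_ids": [1, chunk],  # continuous-batching prefill runs one request at a time
        "attention_mask": [1, 1, chunk, max_seq_len],
        "cache_position": [1, chunk],  # per-token positions
        "batch_position": [],  # which KV-cache slot this request fills
    }
    vllm_decode = {
        "input_ids": [batch_size, 1],
        "attention_mask": [batch_size, 1, 1, max_seq_len],
        "cache_position": [batch_size, 1],
        "batch_position": [],  # filled with -1 at compile time to mark the decoder graph
    }

Both modes also append one `float32` entry per key/value cache tensor (`num_hidden_layers * 2` of them), as the unchanged `input_info.extend(...)` block shows.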
@@ -303,11 +309,14 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):

          return rbln_config

-     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+     @classmethod
+     def _create_runtimes(
+         cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+     ) -> List[rebel.Runtime]:
          device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
          return [
-             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
          ]

      def get_decoder(self):
@@ -337,7 +346,6 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):

          # In greedy decoding
          if past_cached_length == 0:
-
              # padding with prefill_chunk_size
              # TODO left padding + left padding has issue on stoppingcriteria(max_len)
              if cur_len % self.prefill_chunk_size != 0:
@@ -384,7 +392,13 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):

          return model_inputs

-     def forward(
+     def forward(self, *args, **kwargs):
+         if self.use_continuous_batch:
+             return self.forward_cb(*args, **kwargs)
+         else:
+             return self.forward_static(*args, **kwargs)
+
+     def forward_static(
          self,
          input_ids: torch.LongTensor = None,
          attention_mask: Optional[torch.Tensor] = None,
@@ -393,7 +407,6 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
          query_length: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
-
          if past_key_values is not None:
              past_key_values += query_length

@@ -425,3 +438,58 @@ class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
              logits=outputs,
              past_key_values=past_key_values,
          )
+
+     def forward_cb(
+         self,
+         input_ids: torch.LongTensor = None,
+         cache_position: Optional[torch.Tensor] = None, # torch.tensor(,dtype=int32) (1,64) // (4,1)
+         batch_idx: int = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         # prefill_decoder
+         if cache_position.shape[1] > 1:
+             query_length = input_ids.shape[1]
+             attention_mask = self.prefill_attention_mask.clone()
+             for step in range(0, query_length, self.prefill_chunk_size):
+                 if step + self.prefill_chunk_size > query_length:
+                     input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
+                     cache_position = torch.cat(
+                         [
+                             cache_position,
+                             torch.arange(
+                                 query_length,
+                                 step + self.prefill_chunk_size,
+                                 dtype=torch.int32,
+                             ).unsqueeze(0),
+                         ],
+                         dim=-1,
+                     )
+
+                 sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                 sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
+                 attention_mask[:, :, :, :step] = 1
+                 attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+                 outputs, _ = self.prefill_decoder(
+                     sliced_input_ids.contiguous(),
+                     attention_mask.contiguous(),
+                     sliced_cache_positions.contiguous(),
+                     torch.tensor(batch_idx, dtype=torch.int16),
+                 )
+             outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
+         # decoder
+         else:
+             attention_mask = self.decoder_attention_mask.clone()
+             for b_idx in range(self.batch_size):
+                 attention_mask[b_idx, :, :, : cache_position[b_idx].item() + 1] = 1
+
+             outputs = self.decoder(
+                 input_ids.contiguous(),
+                 attention_mask.contiguous(),
+                 cache_position.contiguous(),
+                 torch.tensor(0, dtype=torch.int16),
+             )[0]
+
+         return CausalLMOutputWithPast(
+             logits=outputs,
+         )
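Taken together, these changes replace the monolithic `_export` with the `RBLNModel` hooks (`get_pytorch_model`, `get_compiled_model`, `_create_runtimes`) and add an opt-in continuous-batching ("vllm") path next to the default "static" mode. A minimal usage sketch, assuming the `rbln_*` kwargs are still collected by the base class's `from_pretrained` export flow; the model id and sizes below are placeholders:

    from optimum.rbln import RBLNLlamaForCausalLM

    # Trace and compile a Llama checkpoint for RBLN NPUs.
    # rbln_batching="vllm" selects the continuous-batching graphs added here;
    # omitting it keeps the default "static" mode.
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        export=True,
        rbln_max_seq_len=4096,
        rbln_batch_size=4,
        rbln_batching="vllm",
    )
    model.save_pretrained("llama-2-7b-rbln")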
optimum/rbln/transformers/models/midm/__init__.py (new file)
@@ -0,0 +1,32 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import os
+ from os import environ
+
+
+ this_path = os.path.abspath(__file__)
+ local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
+ environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
+
+ from .modeling_midm import RBLNMidmLMHeadModel
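The path juggling above points `LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM` at the `hf_hub_cached` directory bundled next to this `__init__.py` before `modeling_midm` is imported. On POSIX paths it is equivalent to the following sketch (illustrative only, not part of the package):

    import os

    # Directory containing this __init__.py, plus the bundled "hf_hub_cached" folder.
    local_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hf_hub_cached")
    os.environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir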
optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py (new file)
@@ -0,0 +1,22 @@
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+
+ class MidmBitextConfig(GPT2Config):
+     model_type = "midm-bitext-S"
+
+     def __init__(
+         self,
+         use_absolute_position_embedding: bool = True,
+         use_rotary_position_embedding: bool = False,
+         rotary_percentage: float = 1.0,
+         normalization_type: str = "layernorm",
+         scale_qk_by_inverse_layer_idx: bool = False,
+         *args,
+         **kwargs,
+     ):
+         super().__init__(*args, **kwargs)
+         self.use_absolute_position_embedding = use_absolute_position_embedding
+         self.use_rotary_position_embedding = use_rotary_position_embedding
+         self.rotary_percentage = rotary_percentage
+         self.normalization_type = normalization_type
+         self.scale_qk_by_inverse_layer_idx = scale_qk_by_inverse_layer_idx
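`MidmBitextConfig` only layers the Midm-specific switches (rotary embedding usage and percentage, normalization type, QK scaling) on top of `GPT2Config`. A short illustrative sketch, with the import path inferred from the file location in the list above:

    from optimum.rbln.transformers.models.midm.hf_hub_cached.configuration_midm import MidmBitextConfig

    config = MidmBitextConfig()  # inherits every GPT2Config default
    print(config.model_type)  # "midm-bitext-S"
    print(config.use_rotary_position_embedding)  # False unless overridden

    # Switch from absolute to partial rotary position embeddings.
    config = MidmBitextConfig(
        use_absolute_position_embedding=False,
        use_rotary_position_embedding=True,
        rotary_percentage=0.5,
    )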