optimum-rbln 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
- optimum/rbln/diffusers/models/controlnet.py +93 -23
- optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
- optimum/rbln/diffusers/pipelines/__init__.py +7 -2
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
- optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
- optimum/rbln/modeling_base.py +39 -6
- optimum/rbln/modeling_seq2seq.py +19 -4
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/generation/__init__.py +1 -0
- optimum/rbln/transformers/generation/streamers.py +17 -0
- optimum/rbln/transformers/generation/utils.py +399 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
- optimum/rbln/transformers/models/llama/llama_architecture.py +49 -17
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +759 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +187 -75
- optimum/rbln/transformers/models/midm/__init__.py +32 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +426 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/METADATA +5 -4
- optimum_rbln-0.1.4.dist-info/RECORD +63 -0
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.0.dist-info/RECORD +0 -51
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llama/modeling_llama.py

```diff
@@ -34,16 +34,25 @@ from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
+from ...generation.utils import RBLNGenerationMixin
 from ....modeling_base import RBLNBaseModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.save_utils import maybe_save_preprocessors
+
+
+# FIXME:: Merge Two architecture Codes
 from .llama_architecture import (
     LlamaWrapper,
     wrap_llama,
     unwrap_llama,
 )
 
+from .llama_architecture_cb import (
+    LlamaDynamicBatchWrapper as LlamaWrapper_cb,
+    wrap_llama as wrap_llama_cb,
+)
+
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -56,26 +65,14 @@ if TYPE_CHECKING:
 )
 
 
+SUPPORTED_BATCHING_MODES = ["static", "vllm"]
+
+
 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    # RBLN_Runtimemodule
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.LongTensor = None,
-        cache_position: torch.Tensor = None,
-        **kwargs: Dict[str, Any],
-    ):
-        logits = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-        return logits
-
 
-class RBLNLlamaForCausalLM(RBLNBaseModel):
+class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
```
```diff
@@ -91,21 +88,24 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
     auto_model_class = AutoModelForCausalLM
 
     def __post_init__(self, **kwargs):
-
         self.batch_size = self.rbln_config.meta["rbln_batch_size"]
         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+        self.use_continuous_batch = self.rbln_config.meta["rbln_batching"] == "vllm"
 
+        prefill_batch_size = self.batch_size if not self.use_continuous_batch else 1
         self.prefill_attention_mask = torch.zeros(
-            self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
+            prefill_batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
         )
         self.causal_mask = 1 - torch.triu(
-            torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+            torch.ones(prefill_batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
+        self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
 
         self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
         self.past_cached_length = 0
+        self.right_padding = True
 
     @classmethod
     @torch.no_grad()
```
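The mask arithmetic in `__post_init__` is compact but easy to misread: `torch.triu(..., diagonal=1)` selects the strictly upper triangle, so `1 - triu` yields a lower-triangular (causal) pattern including the diagonal. A standalone sketch with a toy chunk size (the diff uses 128):

```python
import torch

prefill_chunk_size = 4  # toy value; the compiled graphs use 128
causal_mask = 1 - torch.triu(
    torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
)
print(causal_mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```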
```diff
@@ -120,14 +120,23 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNLlamaForCausalLM":
         task = kwargs.pop("task", None)
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+        save_dir_path.mkdir(exist_ok=True)
 
         def update_configs(kwargs):
             hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
```
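The new `model_save_dir` parameter accepts a string, a `Path`, or a `TemporaryDirectory`, and the branches above normalize all three to one `Path`. A standalone sketch of the same normalization (the helper name is mine, not the library's):

```python
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union

def resolve_save_dir(model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None) -> Path:
    """Mirror the diff's logic: default to a fresh TemporaryDirectory,
    otherwise unwrap a TemporaryDirectory or coerce str/Path."""
    if model_save_dir is None:
        # Note: the library keeps `save_dir` referenced on the instance; if only the
        # Path is kept, the directory is deleted once the object is garbage-collected.
        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)
    elif isinstance(model_save_dir, TemporaryDirectory):
        save_dir_path = Path(model_save_dir.name)
    else:
        save_dir_path = Path(model_save_dir)
    save_dir_path.mkdir(exist_ok=True)
    return save_dir_path
```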
```diff
@@ -154,7 +163,10 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
 
-        origin_mehtods = wrap_llama()
+        # FIXME :: This should be moved when wrapping removed.
+        use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
+        origin_mehtods = wrap_llama_cb() if use_continuous_batch else wrap_llama()
+
         model: LlamaForCausalLM = TasksManager.get_model_from_task(
             task=task,
             model_name_or_path=model_id,
```
```diff
@@ -181,14 +193,18 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
             preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
         )
 
-        def compile_llama():
-            wrapped_model = LlamaWrapper(model).eval()
+        def compile_llama(use_continuous_batch, wrapper_cls):
+            wrapped_model = wrapper_cls(model).eval()
 
             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+
+            if use_continuous_batch:
+                batch_index_index = 3
+                dec_example_inputs[batch_index_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
 
             prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs)
             dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs)
```
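Compilation starts from TorchScript tracing with fixed-shape dummy inputs, which is why the prefill and decode graphs are compiled separately. A toy illustration of the `torch.jit.trace` step on a stand-in module (the module and shapes are illustrative, not the library's wrapper):

```python
import torch

class TinyDecoder(torch.nn.Module):
    """Stand-in for the wrapped Llama model; trace() records one forward pass."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, input_ids, attention_mask):
        return self.proj(input_ids) * attention_mask

model = TinyDecoder().eval()
example_inputs = (torch.zeros(1, 8), torch.ones(1, 8))  # analogous to get_dummy_inputs(fill=0)
scripted = torch.jit.trace(model, example_inputs)       # fixed shapes, like the prefill/dec graphs
```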
```diff
@@ -203,8 +219,9 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
             )
 
             # Caching prefill_decoder/decoder I/O
+            cache_index_offset = 4 if use_continuous_batch else 3
             connections = [
-                (prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i])
+                (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
                 for i in range(model.config.num_hidden_layers * 2)
             ]
 
```
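The offset change follows the input layout defined by `get_input_info` later in this diff: inputs 0-2 are `input_ids`, `attention_mask`, and `cache_position`, and continuous batching inserts `batch_position` at index 3, pushing the key/value cache inputs from index 3 to index 4. A sketch of the aliasing with plain indices (names are descriptive, not the compiler's API):

```python
# Input layout assumed from get_input_info() in this diff:
#   static: [input_ids, attention_mask, cache_position, kv_cache_0, kv_cache_1, ...]
#   vllm:   [input_ids, attention_mask, cache_position, batch_position, kv_cache_0, ...]
num_hidden_layers = 32  # e.g. Llama-2-7B
use_continuous_batch = False

cache_index_offset = 4 if use_continuous_batch else 3
# Output 0 is logits; outputs 1..2N are the updated KV caches, aliased back onto
# the matching KV-cache inputs so the cache stays on-device between calls.
connections = [(1 + i, cache_index_offset + i) for i in range(num_hidden_layers * 2)]
```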
```diff
@@ -219,7 +236,8 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         )
         compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
 
-        compile_llama()
+        wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
+        compile_llama(use_continuous_batch=use_continuous_batch, wrapper_cls=wrapper_cls)
         unwrap_llama(origin_mehtods)
 
         rbln_config.save(save_dir_path)
```
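Once the `.rbln` artifact and `rbln_config` are saved, loading follows the usual optimum pattern. A hedged usage sketch (the model ID is illustrative; `export=True` is assumed from optimum's compile-from-checkpoint convention, and the `rbln_*` kwargs are the ones visible in this diff):

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,             # assumption: compile from the HF checkpoint, optimum-style
    rbln_max_seq_len=4096,
    rbln_batch_size=1,
    rbln_batching="static",  # or "vllm" for the new continuous-batching graphs
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
tokens = model.generate(**inputs, max_new_tokens=32)  # via RBLNGenerationMixin / GenerationMixin
print(tokenizer.decode(tokens[0], skip_special_tokens=True))
```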
```diff
@@ -239,27 +257,46 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         model_config: "PretrainedConfig",
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_batching: Optional[str] = None,
     ) -> RBLNConfig:
         meta = {}
 
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
             rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        rbln_batching = "static" if rbln_batching is None else rbln_batching
 
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_batch_size"] = rbln_batch_size
         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+        meta["rbln_batching"] = rbln_batching
+        use_continuous_batching = meta["rbln_batching"] == "vllm"
 
-        def get_input_info(query_length):
+        if rbln_batching not in SUPPORTED_BATCHING_MODES:
+            raise ValueError(
+                f'rbln_batching="{rbln_batching}" is not a supported batch mode, '
+                f"Possible: {SUPPORTED_BATCHING_MODES}"
+            )
+
+        def get_input_info(
+            batch_size,  # should be 1 if continous batch prefill
+            query_length,
+            continuous_batch=False,  # determines the shape of `cache position`
+        ):
             input_info = [
-                ("input_ids", [rbln_batch_size, query_length], "int64"),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                ("input_ids", [batch_size, query_length], "int64"),
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                 (
                     "cache_position",
-                    [],
+                    [batch_size, query_length] if continuous_batch else [],
                     "int32",
                 ),
             ]
+
+            if continuous_batch:
+                input_info.append(("batch_position", [], "int16"))
+
             input_info.extend(
                 [
                     (
```
```diff
@@ -275,10 +312,19 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
                     for i in range(model_config.num_hidden_layers * 2)
                 ]
             )
+
             return input_info
 
-        prefill_input_info = get_input_info(query_length=prefill_chunk_size)
-        dec_input_info = get_input_info(query_length=1)
+        prefill_input_info = get_input_info(
+            batch_size=1 if use_continuous_batching else rbln_batch_size,
+            query_length=prefill_chunk_size,
+            continuous_batch=use_continuous_batching,
+        )
+        dec_input_info = get_input_info(
+            batch_size=rbln_batch_size,
+            query_length=1,
+            continuous_batch=use_continuous_batching,
+        )
 
         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
```
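Concretely, for `rbln_batch_size=4`, `rbln_max_seq_len=4096`, and the fixed 128-token prefill chunk, the two graphs see the following non-cache input shapes. A small reproduction of the shape logic above (standalone, not importing the library):

```python
# Reproduction of get_input_info's shape logic from this diff.
def shapes(batch_size, query_length, max_seq_len, continuous_batch):
    info = {
        "input_ids": [batch_size, query_length],
        "attention_mask": [batch_size, 1, query_length, max_seq_len],
        "cache_position": [batch_size, query_length] if continuous_batch else [],
    }
    if continuous_batch:
        info["batch_position"] = []  # scalar int16 cache-slot index
    return info

print(shapes(4, 128, 4096, False))  # static prefill
print(shapes(4, 1, 4096, False))    # static decode
print(shapes(1, 128, 4096, True))   # vllm prefill (one request at a time)
print(shapes(4, 1, 4096, True))     # vllm decode (all slots at once)
```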
```diff
@@ -321,23 +367,46 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
     # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size,
+        batch_size, cur_len = input_ids.shape
         past_cached_length = past_key_values
-        query_length = hf_input_length - past_cached_length
 
         # In greedy decoding
-        if
-
-
-            self.
-
-
+        if past_cached_length == 0:
+            # padding with prefill_chunk_size
+            # TODO left padding + left padding has issue on stoppingcriteria(max_len)
+            if cur_len % self.prefill_chunk_size != 0:
+                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
+                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
+
+            # padding_side
+            if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
+                self.right_padding = False
+
+            if self.right_padding:
+                self.rightpad_max_len = cur_len
+                prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+                self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len  # dummy_decoder generation length
+                query_length = prompt_min_len.item()
+            else:
+                query_length = cur_len - past_cached_length
+                self.prompt_length = query_length
+                self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
+
+            attention_mask = self.prefill_attention_mask.clone()
             cache_position = torch.tensor(0, dtype=torch.int32)
+
         else:
-
-
+            if self.right_padding:
+                attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
+                attention_mask[:, :, :, : past_cached_length + 1] = 1
+                input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
+            else:
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
+                attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
+                input_ids = input_ids[:, -1:]
+
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
+            query_length = 1
 
         model_inputs = {
             "input_ids": input_ids,
```
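The padding-side probe above rests on one observation: with `batch_size > 1`, a left-padded batch has a real token in the last position of every row, so the final attention-mask column is all ones. A minimal sketch with toy tensors:

```python
import torch

# Two prompts of length 2 and 4, left-padded to length 4:
left_padded = torch.tensor([[0, 0, 1, 1],
                            [1, 1, 1, 1]])
# The same prompts right-padded:
right_padded = torch.tensor([[1, 1, 0, 0],
                             [1, 1, 1, 1]])

print(torch.all(left_padded[..., -1] == 1))   # tensor(True)  -> left padding
print(torch.all(right_padded[..., -1] == 1))  # tensor(False) -> right padding (the default here)
```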
```diff
@@ -349,7 +418,13 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
         return model_inputs
 
-    def forward(
+    def forward(self, *args, **kwargs):
+        if self.use_continuous_batch:
+            return self.forward_cb(*args, **kwargs)
+        else:
+            return self.forward_static(*args, **kwargs)
+
+    def forward_static(
         self,
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
```
```diff
@@ -363,38 +438,20 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
         # prefill_decoder
         if cache_position == 0:
-
-
-
-            attention_mask[:, :, :, :
-
-
-
-
-                    sliced_input_ids,
-                    attention_mask,
-                    cache_position,
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if not self.right_padding:
+                    attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
+
+                outputs = self.prefill_decoder(
+                    input_ids=sliced_input_ids.contiguous(),
+                    attention_mask=attention_mask.contiguous(),
+                    cache_position=cache_position + step,
                 )
-
-            query_length -= self.prefill_chunk_size
-            cache_position += self.prefill_chunk_size
-
-            # prepare input_ids & attention_mask
-            last_input_ids = input_ids[:, cache_position : cache_position + query_length]
-            last_input_ids = torch.nn.functional.pad(last_input_ids, (0, self.prefill_chunk_size - query_length))
+            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
 
-            attention_mask[:, :, :, :cache_position] = 1
-            mask_slice = self.causal_mask[:, :, :query_length, :query_length]
-            attention_mask[:, :, :query_length, cache_position : cache_position + query_length] = mask_slice
-            attention_mask[:, :, :, : self.prompt_length] *= self.prompt_attn_mask[:, :, :, :]
-
-            outputs = self.prefill_decoder(
-                last_input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position,
-            )
-
-            outputs = outputs[:, query_length - 1].unsqueeze(1)
         # decoder
         else:
             outputs = self.decoder(
```
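The rewritten prefill walks the prompt in fixed 128-token chunks: positions already processed become fully visible, and the causal block is stamped onto the current chunk. A toy simulation of the mask bookkeeping (chunk size 4, eight prompt tokens, values are illustrative):

```python
import torch

chunk, max_seq_len, query_length = 4, 12, 8
mask = torch.zeros(1, 1, chunk, max_seq_len, dtype=torch.int64)
causal = (1 - torch.triu(torch.ones(1, 1, chunk, chunk), diagonal=1)).to(torch.int64)

for step in range(0, query_length, chunk):
    mask[:, :, :, :step] = 1                     # fully visible: chunks already prefilled
    mask[:, :, :, step : step + chunk] = causal  # causal within the current chunk
    print(mask[0, 0])                            # watch the visible region grow per chunk
```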
```diff
@@ -407,3 +464,58 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
             logits=outputs,
             past_key_values=past_key_values,
         )
+
+    def forward_cb(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Optional[torch.Tensor] = None,  # torch.tensor(,dtype=int32) (1,64) // (4,1)
+        batch_idx: int = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # prefill_decoder
+        if cache_position.shape[1] > 1:
+            query_length = input_ids.shape[1]
+            attention_mask = self.prefill_attention_mask.clone()
+            for step in range(0, query_length, self.prefill_chunk_size):
+                if step + self.prefill_chunk_size > query_length:
+                    input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
+                    cache_position = torch.cat(
+                        [
+                            cache_position,
+                            torch.arange(
+                                query_length,
+                                step + self.prefill_chunk_size,
+                                dtype=torch.int32,
+                            ).unsqueeze(0),
+                        ],
+                        dim=-1,
+                    )
+
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+                outputs, _ = self.prefill_decoder(
+                    sliced_input_ids.contiguous(),
+                    attention_mask.contiguous(),
+                    sliced_cache_positions.contiguous(),
+                    torch.tensor(batch_idx, dtype=torch.int16),
+                )
+            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
+        # decoder
+        else:
+            attention_mask = self.decoder_attention_mask.clone()
+            for b_idx in range(self.batch_size):
+                attention_mask[b_idx, :, :, : cache_position[b_idx].item() + 1] = 1
+
+            outputs = self.decoder(
+                input_ids.contiguous(),
+                attention_mask.contiguous(),
+                cache_position.contiguous(),
+                torch.tensor(0, dtype=torch.int16),
+            )[0]
+
+        return CausalLMOutputWithPast(
+            logits=outputs,
+        )
```
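Under continuous batching, prefill fills one sequence at a time into its KV-cache slot (`batch_idx`), while decode steps every slot in a single call with per-slot cache positions; the branch on `cache_position.shape[1]` is what separates the two. A hypothetical driver loop illustrating the calling convention (not the vLLM integration itself; `model` is an `RBLNLlamaForCausalLM` compiled with `rbln_batching="vllm"` and four slots):

```python
import torch

# 1) Prefill one request into KV-cache slot 2; a (1, seq_len) cache_position
#    selects the prefill path (cache_position.shape[1] > 1):
prompt_ids = torch.randint(0, 32000, (1, 100))
prefill_pos = torch.arange(100, dtype=torch.int32).unsqueeze(0)
out = model(input_ids=prompt_ids, cache_position=prefill_pos, batch_idx=2)

# 2) Step all four slots by one token; (batch, 1) cache_position selects the decoder path:
next_ids = torch.randint(0, 32000, (4, 1))
decode_pos = torch.tensor([[100], [37], [5], [61]], dtype=torch.int32)
out = model(input_ids=next_ids, cache_position=decode_pos)
print(out.logits.shape)  # e.g. (4, 1, vocab_size)
```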
optimum/rbln/transformers/models/midm/__init__.py

```diff
@@ -0,0 +1,32 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import os
+from os import environ
+
+
+this_path = os.path.abspath(__file__)
+local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
+environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
+
+from .modeling_midm import RBLNMidmLMHeadModel
```
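The module pins an environment variable to the packaged `hf_hub_cached` directory so the bundled Midm remote code is loaded instead of being fetched from the Hub. An equivalent, more portable computation of that path (a sketch; the string-splitting above assumes POSIX path separators):

```python
import os
from pathlib import Path

local_dir = str(Path(__file__).resolve().parent / "hf_hub_cached")
os.environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
```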
optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py

```diff
@@ -0,0 +1,22 @@
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+
+class MidmBitextConfig(GPT2Config):
+    model_type = "midm-bitext-S"
+
+    def __init__(
+        self,
+        use_absolute_position_embedding: bool = True,
+        use_rotary_position_embedding: bool = False,
+        rotary_percentage: float = 1.0,
+        normalization_type: str = "layernorm",
+        scale_qk_by_inverse_layer_idx: bool = False,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.use_absolute_position_embedding = use_absolute_position_embedding
+        self.use_rotary_position_embedding = use_rotary_position_embedding
+        self.rotary_percentage = rotary_percentage
+        self.normalization_type = normalization_type
+        self.scale_qk_by_inverse_layer_idx = scale_qk_by_inverse_layer_idx
```
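`MidmBitextConfig` extends `GPT2Config`, so all regular GPT-2 fields pass through `**kwargs` unchanged while the Midm-specific switches ride alongside. A usage sketch (the field values are illustrative, not Midm's published defaults):

```python
from optimum.rbln.transformers.models.midm.hf_hub_cached.configuration_midm import MidmBitextConfig

config = MidmBitextConfig(
    use_rotary_position_embedding=True,
    use_absolute_position_embedding=False,
    rotary_percentage=0.5,  # rotate half of each head dimension
    n_embd=4096,            # regular GPT2Config kwargs pass through **kwargs
    n_layer=40,
)
print(config.model_type)  # "midm-bitext-S"
```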