optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +282 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
- optimum_rbln-0.1.8.dist-info/RECORD +73 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
- optimum_rbln-0.1.4.dist-info/RECORD +0 -63
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
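The bulk of this release is a refactor of the decoder-only causal-LM models: GPT2, Llama, Gemma, and Midm now share the new `RBLNDecoderOnlyModelForCausalLM` base class under `optimum/rbln/transformers/models/decoderonly/`, which is why the per-model files above shrink so sharply. The hunks that follow show the GPT-2 side of that refactor. As a quick orientation, here is a minimal usage sketch; it assumes the usual optimum-style `from_pretrained(..., export=True)` entry point and a top-level `RBLNGPT2LMHeadModel` export. Only `rbln_max_seq_len` and `rbln_batch_size` are taken from the diff itself (they are the options read by `_get_rbln_config`); the tokenizer and generate calls are illustrative.

# Minimal sketch, assuming the optimum-style export API; rbln_* kwargs come from the diff below.
from transformers import AutoTokenizer

from optimum.rbln import RBLNGPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",
    export=True,            # compile the PyTorch checkpoint into an RBLN artifact
    rbln_max_seq_len=1024,  # falls back to config.n_positions when omitted
    rbln_batch_size=1,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
generated = model.generate(**inputs, max_length=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))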
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293

@@ -23,26 +23,16 @@
 
 import inspect
 import logging
-from
-from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-import
-import torch
-from optimum.exporters import TasksManager
-from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
+from transformers import GPT2LMHeadModel, PretrainedConfig, PreTrainedModel
 
-from ....
-from
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
-from ...generation.utils import RBLNGenerationMixin
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
 logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import (
         AutoFeatureExtractor,
@@ -52,19 +42,12 @@ if TYPE_CHECKING:
     )
 
 
-class
-    def forward(self, *args, **kwargs) -> Union[Tuple, Seq2SeqLMOutput]:
-        outputs = super().forward(*args, **kwargs)
-        logits = outputs
-        return Seq2SeqLMOutput(logits=logits)
-
-
-class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -73,29 +56,9 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
 
     """
 
-
-
-
-
-    def __post_init__(self, **kwargs):
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-        batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].input_info[0][1][0]
-        self.prefill_attention_mask = torch.zeros(
-            batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0])
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1])
-        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
-        self.past_cached_length = 0
-
-    def can_generate(self):
-        return True
+    @classmethod
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         """This is the key method to implement RBLN-GPT2.
@@ -112,126 +75,6 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
                 return redirect(val)
         return val
 
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # TODO(jongho): implement
-        raise NotImplementedError
-
-    @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNGPT2LMHeadModel":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-            save_dir_path.mkdir(exist_ok=True)
-
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: GPT2LMHeadModel = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_gpt2():
-            wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-            prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-            prefill_ir = rebel.torchscript_to_ir(
-                prefill_scripted_model,
-                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            )
-
-            connections = [
-                (prefill_ir.outputs[1], prefill_ir.inputs[1]),
-            ]
-
-            compiled_model = rebel.compile(
-                prefill_ir,
-                dec_ir,
-                connections=connections,
-                fusion=prefill_rbln_runtime_config.fusion,
-                npu=prefill_rbln_runtime_config.npu,
-                tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                use_weight_sharing=True,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile_gpt2()
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
-
     @classmethod
     def _get_rbln_config(
         cls,
@@ -239,153 +82,74 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         model_config: "PretrainedConfig",
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
-
+        **kwargs,
     ) -> RBLNConfig:
         meta = {}
 
-        default_max_length = getattr(model_config, "n_positions", None)
-        for tokenizer in preprocessors:
-            default_max_length = default_max_length or getattr(tokenizer, "max_len_single_sentence", None)
-
         prefill_chunk_size = 128
+        if rbln_max_seq_len is None:  # differenct from llama
+            rbln_max_seq_len = getattr(model_config, "n_positions", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = default_max_length
-
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
-        if rbln_pad_token_id is None:
-            rbln_pad_token_id = 50256
-
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
         meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["
-
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        meta["rbln_batch_size"] = rbln_batch_size
+        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
 
-        def get_input_info(
-
-
-
-
-
-
-
-
-
-
-
-                ],
-                "float32",
-            ),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+        def get_input_info(
+            batch_size,
+            query_length,
+        ):
+            head_dim = (
+                model_config.head_dim
+                if hasattr(model_config, "head_dim")
+                else model_config.hidden_size // model_config.n_head
+            )
+            input_info = [
+                ("input_ids", [batch_size, query_length], "int64"),
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                 (
                     "cache_position",
-                    [],
+                    [batch_size, query_length],
                     "int32",
                 ),
+                ("batch_position", [], "int16"),
             ]
 
-
-
-
+            input_info.extend(
+                [
+                    (
+                        f"past_key_values_{i}",
+                        [
+                            rbln_batch_size,
+                            model_config.n_head,  # differenct from llama
+                            rbln_max_seq_len,
+                            head_dim,
+                        ],
+                        "float32",
+                    )
+                    for i in range(model_config.n_layer * 2)  # differenct from llama
+                ]
+            )
+
+            return input_info
+
+        prefill_input_info = get_input_info(
+            batch_size=1,
+            query_length=prefill_chunk_size,
+        )
+        dec_input_info = get_input_info(
+            batch_size=rbln_batch_size,
+            query_length=1,
+        )
 
         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
 
+        dec_rbln_runtime_config.batch_size = rbln_batch_size
+
         rbln_config = RBLNConfig.from_rbln_runtime_configs(
             [prefill_rbln_runtime_config, dec_rbln_runtime_config],
            _rbln_meta=meta,
         )
 
         return rbln_config
-
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        # In greedy decoding
-        if past_cached_length == 0:
-            self.prompt_ids = input_ids
-            self.rightpad_max_len = cur_len
-            prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-            self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
-
-            if cur_len % self.prefill_chunk_size == 0:
-                pad_len = 0
-            else:
-                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-            input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-            query_length = prompt_min_len.item()
-        else:
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-            attention_mask[:, :, :, : cache_position + 1] = 1
-            input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            # below are rbln-related kwargs
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-        if past_key_values is not None:
-            past_key_values += query_length
-
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                output = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position + step,
-                )
-
-            idx = query_length % self.prefill_chunk_size - 1
-            output = output.logits[:, idx].unsqueeze(1)
-
-        else:
-            output = self.decoder(
-                input_ids=input_ids.contiguous(),
-                attention_mask=attention_mask.contiguous(),
-                cache_position=cache_position,
-            )
-            output = output.logits
-
-        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
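For a concrete sense of what the reworked `_get_rbln_config` hands to the compiler, the sketch below recomputes the `input_info` shapes produced by the `get_input_info` helper added above, for a stock GPT-2 small config (n_layer=12, n_head=12, hidden_size=768, n_positions=1024) with `rbln_batch_size=1`. It is a standalone, illustrative copy, not part of the package: the prefill graph is built with batch size 1 and a 128-token chunk, the decode graph with the user batch size and a single query token, and both declare the same `past_key_values_*` tensors at the full `rbln_batch_size` so the two runtimes can share one static KV cache.

# Illustrative re-computation of the shapes emitted by the new get_input_info helper.
def get_input_info(batch_size, query_length, *, n_layer=12, n_head=12,
                   hidden_size=768, rbln_max_seq_len=1024, rbln_batch_size=1):
    head_dim = hidden_size // n_head  # GPT2Config has no head_dim attribute
    input_info = [
        ("input_ids", [batch_size, query_length], "int64"),
        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
        ("cache_position", [batch_size, query_length], "int32"),
        ("batch_position", [], "int16"),
    ]
    # one key tensor and one value tensor per layer, always at the full rbln_batch_size
    input_info += [
        (f"past_key_values_{i}", [rbln_batch_size, n_head, rbln_max_seq_len, head_dim], "float32")
        for i in range(n_layer * 2)
    ]
    return input_info

prefill_input_info = get_input_info(batch_size=1, query_length=128)  # prefill_chunk_size
dec_input_info = get_input_info(batch_size=1, query_length=1)        # one token per decode step

print(prefill_input_info[1])  # ('attention_mask', [1, 1, 128, 1024], 'int64')
print(dec_input_info[4])      # ('past_key_values_0', [1, 12, 1024, 64], 'float32')
print(len(dec_input_info))    # 4 activation inputs + 24 KV-cache tensors = 28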
|