optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +110 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
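
The headline change in 0.1.8 is the new decoderonly module: RBLNDecoderOnlyModelForCausalLM now carries the prefill/decode compilation and runtime plumbing that backends such as Llama previously duplicated (the GPT-2, Gemma, and Midm changes in this release appear to follow the same pattern), which is why modeling_llama.py and llama_architecture_cb.py shrink so sharply below. As a rough usage sketch only — the from_pretrained(..., export=True) entry point, the model id, and the exact keyword names are assumptions inferred from the rbln_* options visible in the removed _get_rbln_config, not confirmed against the 0.1.8 API:

# Hypothetical sketch; the entry point and kwarg names are assumptions, not taken from this diff.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # example model id
    export=True,                 # assumed optimum-style "compile on load" flag
    rbln_max_seq_len=4096,       # option names mirror the removed _get_rbln_config
    rbln_batch_size=1,
)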
optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -21,54 +21,20 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
+import inspect
 import logging
-from typing import
+from typing import Any, Callable
 
-import
-import rebel  # noqa: F401
+from transformers import LlamaForCausalLM, PreTrainedModel
 
-from
-from
-
-from ...generation.utils import RBLNGenerationMixin
-from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-
-
-# FIXME:: Merge Two architecture Codes
-from .llama_architecture import (
-    LlamaWrapper,
-    wrap_llama,
-    unwrap_llama,
-)
-
-from .llama_architecture_cb import (
-    LlamaDynamicBatchWrapper as LlamaWrapper_cb,
-    wrap_llama as wrap_llama_cb,
-)
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .llama_architecture import LlamaWrapper
 
 
 logger = logging.getLogger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
-
-
-SUPPORTED_BATCHING_MODES = ["static", "vllm"]
-
-
-class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
 
-class RBLNLlamaForCausalLM(
+class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -79,251 +45,9 @@ class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    main_input_name = "input_ids"
-    auto_model_class = AutoModelForCausalLM
-
-    def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.use_continuous_batch = self.rbln_config.meta["rbln_batching"] == "vllm"
-
-        prefill_batch_size = self.batch_size if not self.use_continuous_batch else 1
-        self.prefill_attention_mask = torch.zeros(
-            prefill_batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(prefill_batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-        self.decoder_attention_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-
-        self.prefill_decoder = RBLNRuntimeModel(runtime=self.model[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
-        self.past_cached_length = 0
-        self.right_padding = True
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        """
-        Update user-given kwargs to get proper pytorch model.
-
-        For example, `torchscript`=True should be set because torch.jit
-        does not support `transformers` output instances as module output;
-        """
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-                "torch_dtype": torch.float32,
-                "_attn_implementation": "eager",
-            }
-        )
-        return kwargs
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> PreTrainedModel:
-        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
-            config = AutoConfig.from_pretrained(model_id)
-            if hf_position_embedding := getattr(config, "max_position_embeddings", None):
-                if hf_position_embedding < rbln_max_seq_len:
-                    logger.warning(
-                        f"`rbln_max_seq_len` is larger than original config({hf_position_embedding})."
-                        "This may lead to incorrect inferences of the model."
-                    )
-            kwargs.update({"max_position_embeddings": rbln_max_seq_len})
-
-        # FIXME :: This should be moved when wrapping removed.
-        use_continuous_batch = rbln_config_kwargs.get("rbln_batching", "static") == "vllm"
-        wrap_llama_cb() if use_continuous_batch else wrap_llama()
-
-        model = super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_config_kwargs=rbln_config_kwargs,
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
-            **kwargs,
-        )
-
-        unwrap_llama()
-
-        return model
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        use_continuous_batch = rbln_config.meta["rbln_batching"] == "vllm"
-
-        wrapper_cls = LlamaWrapper_cb if use_continuous_batch else LlamaWrapper
-
-        wrapped_model = wrapper_cls(model).eval()
-
-        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
-
-        if use_continuous_batch:
-            batch_index_index = 3
-            dec_example_inputs[batch_index_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
-
-        wrap_llama_cb() if use_continuous_batch else wrap_llama()
-
-        prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)
-
-        unwrap_llama()
-
-        prefill_ir = rebel.torchscript_to_ir(
-            prefill_scripted_model,
-            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-        )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-        )
-
-        # Caching prefill_decoder/decoder I/O
-        cache_index_offset = 4 if use_continuous_batch else 3
-        connections = [
-            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-            for i in range(model.config.num_hidden_layers * 2)
-        ]
-
-        compiled_model = rebel.compile(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_runtime_config.fusion,
-            npu=prefill_rbln_runtime_config.npu,
-            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-            use_weight_sharing=True,
-        )
-        return compiled_model
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_batching: Optional[str] = None,
-    ) -> RBLNConfig:
-        meta = {}
-
-        prefill_chunk_size = 128
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_batching = "static" if rbln_batching is None else rbln_batching
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
-        meta["rbln_batching"] = rbln_batching
-        use_continuous_batching = meta["rbln_batching"] == "vllm"
-
-        if rbln_batching not in SUPPORTED_BATCHING_MODES:
-            raise ValueError(
-                f'rbln_batching="{rbln_batching}" is not a supported batch mode, '
-                f"Possible: {SUPPORTED_BATCHING_MODES}"
-            )
-
-        def get_input_info(
-            batch_size, # should be 1 if continous batch prefill
-            query_length,
-            continuous_batch=False, # determines the shape of `cache position`
-        ):
-            input_info = [
-                ("input_ids", [batch_size, query_length], "int64"),
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                (
-                    "cache_position",
-                    [batch_size, query_length] if continuous_batch else [],
-                    "int32",
-                ),
-            ]
-
-            if continuous_batch:
-                input_info.append(("batch_position", [], "int16"))
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_batch_size,
-                            model_config.num_key_value_heads,
-                            rbln_max_seq_len,
-                            model_config.hidden_size // model_config.num_attention_heads,
-                        ],
-                        "float32",
-                    )
-                    for i in range(model_config.num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
-            batch_size=1 if use_continuous_batching else rbln_batch_size,
-            query_length=prefill_chunk_size,
-            continuous_batch=use_continuous_batching,
-        )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
-            query_length=1,
-            continuous_batch=use_continuous_batching,
-        )
-
-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-        dec_rbln_runtime_config.batch_size = rbln_batch_size
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
-        )
-
-        return rbln_config
-
     @classmethod
-    def
-
-    ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def get_decoder(self):
-        return self.decoder
-
-    def can_generate(self):
-        return True
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return LlamaWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
@@ -335,161 +59,3 @@ class RBLNLlamaForCausalLM(RBLNModel, RBLNGenerationMixin):
             return redirect(val)
 
         return val
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        raise NotImplementedError
-
-    # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        # In greedy decoding
-        if past_cached_length == 0:
-            # padding with prefill_chunk_size
-            # TODO left padding + left padding has issue on stoppingcriteria(max_len)
-            if cur_len % self.prefill_chunk_size != 0:
-                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
-
-            # padding_side
-            if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
-                self.right_padding = False
-
-            if self.right_padding:
-                self.rightpad_max_len = cur_len
-                prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
-                self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len # dummy_decoder generation length
-                query_length = prompt_min_len.item()
-            else:
-                query_length = cur_len - past_cached_length
-                self.prompt_length = query_length
-                self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
-
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(0, dtype=torch.int32)
-
-        else:
-            if self.right_padding:
-                attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
-                attention_mask[:, :, :, : past_cached_length + 1] = 1
-                input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
-            else:
-                attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
-                attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-                input_ids = input_ids[:, -1:]
-
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(self, *args, **kwargs):
-        if self.use_continuous_batch:
-            return self.forward_cb(*args, **kwargs)
-        else:
-            return self.forward_static(*args, **kwargs)
-
-    def forward_static(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: int = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        if past_key_values is not None:
-            past_key_values += query_length
-
-        # prefill_decoder
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                if not self.right_padding:
-                    attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
-
-                outputs = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position + step,
-                )
-            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-
-        # decoder
-        else:
-            outputs = self.decoder(
-                input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position=cache_position,
-            )
-
-        return CausalLMOutputWithPast(
-            logits=outputs,
-            past_key_values=past_key_values,
-        )
-
-    def forward_cb(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Optional[torch.Tensor] = None,  # torch.tensor(,dtype=int32) (1,64) // (4,1)
-        batch_idx: int = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefill_decoder
-        if cache_position.shape[1] > 1:
-            query_length = input_ids.shape[1]
-            attention_mask = self.prefill_attention_mask.clone()
-            for step in range(0, query_length, self.prefill_chunk_size):
-                if step + self.prefill_chunk_size > query_length:
-                    input_ids = torch.nn.functional.pad(input_ids, (0, step + self.prefill_chunk_size - query_length))
-                    cache_position = torch.cat(
-                        [
-                            cache_position,
-                            torch.arange(
-                                query_length,
-                                step + self.prefill_chunk_size,
-                                dtype=torch.int32,
-                            ).unsqueeze(0),
-                        ],
-                        dim=-1,
-                    )
-
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-                outputs, _ = self.prefill_decoder(
-                    sliced_input_ids.contiguous(),
-                    attention_mask.contiguous(),
-                    sliced_cache_positions.contiguous(),
-                    torch.tensor(batch_idx, dtype=torch.int16),
-                )
-            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
-        # decoder
-        else:
-            attention_mask = self.decoder_attention_mask.clone()
-            for b_idx in range(self.batch_size):
-                attention_mask[b_idx, :, :, : cache_position[b_idx].item() + 1] = 1
-
-            outputs = self.decoder(
-                input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position.contiguous(),
-                torch.tensor(0, dtype=torch.int16),
-            )[0]
-
-        return CausalLMOutputWithPast(
-            logits=outputs,
-        )
optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py

@@ -10,7 +10,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""Tokenization class for model Midm_bitext_tonkenizer."""
+
 import os
 import re
 import warnings
optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py

@@ -817,7 +817,6 @@ class MidmModel(MidmPreTrainedModel):
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             # Model parallel
             if self.model_parallel:
                 torch.cuda.set_device(hidden_states.device)
@@ -833,7 +832,6 @@ class MidmModel(MidmPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -1174,7 +1172,6 @@ class MidmDoubleHeadsModel(MidmPreTrainedModel):
         return_dict=None,
         **kwargs,
     ):
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         transformer_outputs = self.transformer(
@@ -1445,7 +1442,6 @@ def get_submodule(module, target: str):  # -> "Module":
     mod: torch.nn.Module = module
 
     for item in atoms:
-
         if not hasattr(mod, item):
             raise AttributeError(mod._get_name() + " has no " "attribute `" + item + "`")
 