optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +4 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/configuration_utils.py +4 -4
- optimum/rbln/diffusers/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/pipelines/__init__.py +1 -5
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +2 -2
- optimum/rbln/modeling_base.py +12 -4
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/configuration_generic.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +1 -4
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +6 -16
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +19 -250
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
- optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
- optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
- optimum/rbln/utils/runtime_utils.py +3 -3
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_base.py
CHANGED
|
@@ -348,7 +348,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
348
348
|
model_id: Union[str, Path],
|
|
349
349
|
export: bool = False,
|
|
350
350
|
rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
|
|
351
|
-
**kwargs:
|
|
351
|
+
**kwargs: Any,
|
|
352
352
|
) -> "RBLNBaseModel":
|
|
353
353
|
"""
|
|
354
354
|
The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
|
|
@@ -523,10 +523,18 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
523
523
|
# First copy everything to a temporary directory
|
|
524
524
|
shutil.copytree(real_save_dir, tmp_dir)
|
|
525
525
|
|
|
526
|
-
# If everything succeeded,
|
|
526
|
+
# If everything succeeded, move files to target directory
|
|
527
527
|
if os.path.exists(save_directory_path):
|
|
528
|
-
|
|
529
|
-
|
|
528
|
+
# Move files from tmp_dir to existing directory (overwrite existing files)
|
|
529
|
+
for item in os.listdir(tmp_dir):
|
|
530
|
+
src_path = os.path.join(tmp_dir, item)
|
|
531
|
+
dst_path = os.path.join(save_directory_path, item)
|
|
532
|
+
shutil.move(src_path, dst_path)
|
|
533
|
+
# Clean up empty tmp_dir
|
|
534
|
+
os.rmdir(tmp_dir)
|
|
535
|
+
else:
|
|
536
|
+
# If target doesn't exist, just rename tmp_dir to target
|
|
537
|
+
os.rename(tmp_dir, save_directory_path)
|
|
530
538
|
|
|
531
539
|
except Exception as e:
|
|
532
540
|
# Clean up the temporary directory if anything fails
|
optimum/rbln/ops/attn.py
CHANGED
|
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
|
|
|
53
53
|
return torch.empty_like(q)
|
|
54
54
|
|
|
55
55
|
|
|
56
|
+
@torch.library.custom_op(
|
|
57
|
+
"rbln_custom_ops::paged_attn_decode_kv_fp8",
|
|
58
|
+
mutates_args=(["kcache", "vcache"]),
|
|
59
|
+
)
|
|
60
|
+
def paged_attn_decode_kv_fp8(
|
|
61
|
+
q: Tensor,
|
|
62
|
+
k: Tensor,
|
|
63
|
+
v: Tensor,
|
|
64
|
+
mask: Tensor,
|
|
65
|
+
kcache: Tensor,
|
|
66
|
+
vcache: Tensor,
|
|
67
|
+
seq: Tensor,
|
|
68
|
+
scale: Tensor,
|
|
69
|
+
block_table: Tensor,
|
|
70
|
+
block_size: int,
|
|
71
|
+
k_scale: Tensor,
|
|
72
|
+
v_scale: Tensor,
|
|
73
|
+
) -> Tensor:
|
|
74
|
+
return torch.empty_like(q)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@paged_attn_decode_kv_fp8.register_fake
|
|
78
|
+
def paged_attn_decode_kv_fp8_fake(
|
|
79
|
+
q: Tensor,
|
|
80
|
+
k: Tensor,
|
|
81
|
+
v: Tensor,
|
|
82
|
+
mask: Tensor,
|
|
83
|
+
kcache: Tensor,
|
|
84
|
+
vcache: Tensor,
|
|
85
|
+
seq: Tensor,
|
|
86
|
+
scale: Tensor,
|
|
87
|
+
block_table: Tensor,
|
|
88
|
+
block_size: int,
|
|
89
|
+
k_scale: Tensor,
|
|
90
|
+
v_scale: Tensor,
|
|
91
|
+
) -> Tensor:
|
|
92
|
+
return torch.empty_like(q)
|
|
93
|
+
|
|
94
|
+
|
|
56
95
|
@torch.library.custom_op(
|
|
57
96
|
"rbln_custom_ops::paged_attn_prefill",
|
|
58
97
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
|
|
|
112
151
|
return torch.empty_like(q)
|
|
113
152
|
|
|
114
153
|
|
|
154
|
+
@torch.library.custom_op(
|
|
155
|
+
"rbln_custom_ops::paged_attn_prefill_kv_fp8",
|
|
156
|
+
mutates_args=(["kcache", "vcache"]),
|
|
157
|
+
)
|
|
158
|
+
def paged_attn_prefill_kv_fp8(
|
|
159
|
+
q: Tensor,
|
|
160
|
+
k: Tensor,
|
|
161
|
+
v: Tensor,
|
|
162
|
+
mask: Tensor,
|
|
163
|
+
kcache: Tensor,
|
|
164
|
+
vcache: Tensor,
|
|
165
|
+
seq: Tensor,
|
|
166
|
+
scale: Tensor,
|
|
167
|
+
block_table: Tensor,
|
|
168
|
+
block_size: int,
|
|
169
|
+
k_scale: Tensor,
|
|
170
|
+
v_scale: Tensor,
|
|
171
|
+
) -> Tensor:
|
|
172
|
+
return torch.empty_like(q)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@paged_attn_prefill_kv_fp8.register_fake
|
|
176
|
+
def paged_attn_prefill_kv_fp8_fake(
|
|
177
|
+
q: Tensor,
|
|
178
|
+
k: Tensor,
|
|
179
|
+
v: Tensor,
|
|
180
|
+
mask: Tensor,
|
|
181
|
+
kcache: Tensor,
|
|
182
|
+
vcache: Tensor,
|
|
183
|
+
seq: Tensor,
|
|
184
|
+
scale: Tensor,
|
|
185
|
+
block_table: Tensor,
|
|
186
|
+
block_size: int,
|
|
187
|
+
k_scale: Tensor,
|
|
188
|
+
v_scale: Tensor,
|
|
189
|
+
) -> Tensor:
|
|
190
|
+
return torch.empty_like(q)
|
|
191
|
+
|
|
192
|
+
|
|
115
193
|
@torch.library.custom_op(
|
|
116
194
|
"rbln_custom_ops::paged_causal_attn_decode",
|
|
117
195
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
|
|
|
236
314
|
return torch.empty_like(q)
|
|
237
315
|
|
|
238
316
|
|
|
317
|
+
@torch.library.custom_op(
|
|
318
|
+
"rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
|
|
319
|
+
mutates_args=(["kcache", "vcache"]),
|
|
320
|
+
)
|
|
321
|
+
def paged_causal_attn_decode_kv_fp8(
|
|
322
|
+
q: Tensor,
|
|
323
|
+
k: Tensor,
|
|
324
|
+
v: Tensor,
|
|
325
|
+
kcache: Tensor,
|
|
326
|
+
vcache: Tensor,
|
|
327
|
+
seq: Tensor,
|
|
328
|
+
scale: Tensor,
|
|
329
|
+
block_table: Tensor,
|
|
330
|
+
block_size: int,
|
|
331
|
+
k_scale: Tensor,
|
|
332
|
+
v_scale: Tensor,
|
|
333
|
+
mask: Optional[Tensor] = None,
|
|
334
|
+
) -> Tensor:
|
|
335
|
+
return torch.empty_like(q)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
@paged_causal_attn_decode_kv_fp8.register_fake
|
|
339
|
+
def paged_causal_attn_decode_kv_fp8_fake(
|
|
340
|
+
q: Tensor,
|
|
341
|
+
k: Tensor,
|
|
342
|
+
v: Tensor,
|
|
343
|
+
kcache: Tensor,
|
|
344
|
+
vcache: Tensor,
|
|
345
|
+
seq: Tensor,
|
|
346
|
+
scale: Tensor,
|
|
347
|
+
block_table: Tensor,
|
|
348
|
+
block_size: int,
|
|
349
|
+
k_scale: Tensor,
|
|
350
|
+
v_scale: Tensor,
|
|
351
|
+
mask: Optional[Tensor] = None,
|
|
352
|
+
) -> Tensor:
|
|
353
|
+
return torch.empty_like(q)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@torch.library.custom_op(
|
|
357
|
+
"rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
|
|
358
|
+
mutates_args=(["kcache", "vcache"]),
|
|
359
|
+
)
|
|
360
|
+
def paged_causal_attn_prefill_kv_fp8(
|
|
361
|
+
q: Tensor,
|
|
362
|
+
k: Tensor,
|
|
363
|
+
v: Tensor,
|
|
364
|
+
kcache: Tensor,
|
|
365
|
+
vcache: Tensor,
|
|
366
|
+
seq: Tensor,
|
|
367
|
+
scale: Tensor,
|
|
368
|
+
block_table: Tensor,
|
|
369
|
+
block_size: int,
|
|
370
|
+
is_bidirectional: bool,
|
|
371
|
+
k_scale: Tensor,
|
|
372
|
+
v_scale: Tensor,
|
|
373
|
+
mask: Optional[Tensor] = None,
|
|
374
|
+
) -> Tensor:
|
|
375
|
+
return torch.empty_like(q)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@paged_causal_attn_prefill_kv_fp8.register_fake
|
|
379
|
+
def paged_causal_attn_prefill_kv_fp8_fake(
|
|
380
|
+
q: Tensor,
|
|
381
|
+
k: Tensor,
|
|
382
|
+
v: Tensor,
|
|
383
|
+
kcache: Tensor,
|
|
384
|
+
vcache: Tensor,
|
|
385
|
+
seq: Tensor,
|
|
386
|
+
scale: Tensor,
|
|
387
|
+
block_table: Tensor,
|
|
388
|
+
block_size: int,
|
|
389
|
+
is_bidirectional: bool,
|
|
390
|
+
k_scale: Tensor,
|
|
391
|
+
v_scale: Tensor,
|
|
392
|
+
mask: Optional[Tensor] = None,
|
|
393
|
+
) -> Tensor:
|
|
394
|
+
return torch.empty_like(q)
|
|
395
|
+
|
|
396
|
+
|
|
239
397
|
@torch.library.custom_op(
|
|
240
398
|
"rbln_custom_ops::paged_add_softmax_attn_decode",
|
|
241
399
|
mutates_args=(["kcache", "vcache"]),
|
optimum/rbln/ops/flash_attn.py
CHANGED
|
@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
|
|
|
59
59
|
return torch.empty_like(q)
|
|
60
60
|
|
|
61
61
|
|
|
62
|
+
@torch.library.custom_op(
|
|
63
|
+
"rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
|
|
64
|
+
mutates_args=(["kcache", "vcache"]),
|
|
65
|
+
)
|
|
66
|
+
def paged_flash_attn_decode_kv_fp8(
|
|
67
|
+
q: Tensor,
|
|
68
|
+
k: Tensor,
|
|
69
|
+
v: Tensor,
|
|
70
|
+
mask: Tensor,
|
|
71
|
+
kcache: Tensor,
|
|
72
|
+
vcache: Tensor,
|
|
73
|
+
seq: Tensor,
|
|
74
|
+
scale: Tensor,
|
|
75
|
+
block_table: Tensor,
|
|
76
|
+
block_size: int,
|
|
77
|
+
partition: int,
|
|
78
|
+
k_scale: Tensor,
|
|
79
|
+
v_scale: Tensor,
|
|
80
|
+
) -> Tensor:
|
|
81
|
+
return torch.empty_like(q)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@paged_flash_attn_decode_kv_fp8.register_fake
|
|
85
|
+
def paged_flash_attn_decode_kv_fp8_fake(
|
|
86
|
+
q: Tensor,
|
|
87
|
+
k: Tensor,
|
|
88
|
+
v: Tensor,
|
|
89
|
+
mask: Tensor,
|
|
90
|
+
kcache: Tensor,
|
|
91
|
+
vcache: Tensor,
|
|
92
|
+
seq: Tensor,
|
|
93
|
+
scale: Tensor,
|
|
94
|
+
block_table: Tensor,
|
|
95
|
+
block_size: int,
|
|
96
|
+
partition: int,
|
|
97
|
+
k_scale: Tensor,
|
|
98
|
+
v_scale: Tensor,
|
|
99
|
+
) -> Tensor:
|
|
100
|
+
return torch.empty_like(q)
|
|
101
|
+
|
|
102
|
+
|
|
62
103
|
@torch.library.custom_op(
|
|
63
104
|
"rbln_custom_ops::paged_flash_attn_prefill",
|
|
64
105
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
|
|
|
100
141
|
return torch.empty_like(q)
|
|
101
142
|
|
|
102
143
|
|
|
144
|
+
@torch.library.custom_op(
|
|
145
|
+
"rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
|
|
146
|
+
mutates_args=(["kcache", "vcache"]),
|
|
147
|
+
)
|
|
148
|
+
def paged_flash_attn_prefill_kv_fp8(
|
|
149
|
+
q: Tensor,
|
|
150
|
+
k: Tensor,
|
|
151
|
+
v: Tensor,
|
|
152
|
+
mask: Tensor,
|
|
153
|
+
kcache: Tensor,
|
|
154
|
+
vcache: Tensor,
|
|
155
|
+
seq: Tensor,
|
|
156
|
+
scale: Tensor,
|
|
157
|
+
block_table: Tensor,
|
|
158
|
+
block_size: int,
|
|
159
|
+
partition: int,
|
|
160
|
+
k_scale: Tensor,
|
|
161
|
+
v_scale: Tensor,
|
|
162
|
+
) -> Tensor:
|
|
163
|
+
return torch.empty_like(q)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@paged_flash_attn_prefill_kv_fp8.register_fake
|
|
167
|
+
def paged_flash_attn_prefill_kv_fp8_fake(
|
|
168
|
+
q: Tensor,
|
|
169
|
+
k: Tensor,
|
|
170
|
+
v: Tensor,
|
|
171
|
+
mask: Tensor,
|
|
172
|
+
kcache: Tensor,
|
|
173
|
+
vcache: Tensor,
|
|
174
|
+
seq: Tensor,
|
|
175
|
+
scale: Tensor,
|
|
176
|
+
block_table: Tensor,
|
|
177
|
+
block_size: int,
|
|
178
|
+
partition: int,
|
|
179
|
+
k_scale: Tensor,
|
|
180
|
+
v_scale: Tensor,
|
|
181
|
+
) -> Tensor:
|
|
182
|
+
return torch.empty_like(q)
|
|
183
|
+
|
|
184
|
+
|
|
103
185
|
@torch.library.custom_op(
|
|
104
186
|
"rbln_custom_ops::paged_flash_causal_attn_decode",
|
|
105
187
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
|
|
|
141
223
|
return torch.empty_like(q)
|
|
142
224
|
|
|
143
225
|
|
|
226
|
+
@torch.library.custom_op(
|
|
227
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
|
|
228
|
+
mutates_args=(["kcache", "vcache"]),
|
|
229
|
+
)
|
|
230
|
+
def paged_flash_causal_attn_decode_kv_fp8(
|
|
231
|
+
q: Tensor,
|
|
232
|
+
k: Tensor,
|
|
233
|
+
v: Tensor,
|
|
234
|
+
kcache: Tensor,
|
|
235
|
+
vcache: Tensor,
|
|
236
|
+
seq: Tensor,
|
|
237
|
+
scale: Tensor,
|
|
238
|
+
block_table: Tensor,
|
|
239
|
+
block_size: int,
|
|
240
|
+
partition: int,
|
|
241
|
+
k_scale: Tensor,
|
|
242
|
+
v_scale: Tensor,
|
|
243
|
+
mask: Optional[Tensor] = None,
|
|
244
|
+
) -> Tensor:
|
|
245
|
+
return torch.empty_like(q)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@paged_flash_causal_attn_decode_kv_fp8.register_fake
|
|
249
|
+
def paged_flash_causal_attn_decode_kv_fp8_fake(
|
|
250
|
+
q: Tensor,
|
|
251
|
+
k: Tensor,
|
|
252
|
+
v: Tensor,
|
|
253
|
+
kcache: Tensor,
|
|
254
|
+
vcache: Tensor,
|
|
255
|
+
seq: Tensor,
|
|
256
|
+
scale: Tensor,
|
|
257
|
+
block_table: Tensor,
|
|
258
|
+
block_size: int,
|
|
259
|
+
partition: int,
|
|
260
|
+
k_scale: Tensor,
|
|
261
|
+
v_scale: Tensor,
|
|
262
|
+
mask: Optional[Tensor] = None,
|
|
263
|
+
) -> Tensor:
|
|
264
|
+
return torch.empty_like(q)
|
|
265
|
+
|
|
266
|
+
|
|
144
267
|
@torch.library.custom_op(
|
|
145
268
|
"rbln_custom_ops::paged_flash_causal_attn_prefill",
|
|
146
269
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
|
|
|
182
305
|
mask: Optional[Tensor] = None,
|
|
183
306
|
) -> Tensor:
|
|
184
307
|
return torch.empty_like(q)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@torch.library.custom_op(
|
|
311
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
|
|
312
|
+
mutates_args=(["kcache", "vcache"]),
|
|
313
|
+
)
|
|
314
|
+
def paged_flash_causal_attn_prefill_kv_fp8(
|
|
315
|
+
q: Tensor,
|
|
316
|
+
k: Tensor,
|
|
317
|
+
v: Tensor,
|
|
318
|
+
kcache: Tensor,
|
|
319
|
+
vcache: Tensor,
|
|
320
|
+
seq: Tensor,
|
|
321
|
+
scale: Tensor,
|
|
322
|
+
block_table: Tensor,
|
|
323
|
+
block_size: int,
|
|
324
|
+
partition: int,
|
|
325
|
+
is_bidirectional: bool,
|
|
326
|
+
k_scale: Tensor,
|
|
327
|
+
v_scale: Tensor,
|
|
328
|
+
mask: Optional[Tensor] = None,
|
|
329
|
+
) -> Tensor:
|
|
330
|
+
return torch.empty_like(q)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@paged_flash_causal_attn_prefill_kv_fp8.register_fake
|
|
334
|
+
def paged_flash_causal_attn_prefill_kv_fp8_fake(
|
|
335
|
+
q: Tensor,
|
|
336
|
+
k: Tensor,
|
|
337
|
+
v: Tensor,
|
|
338
|
+
kcache: Tensor,
|
|
339
|
+
vcache: Tensor,
|
|
340
|
+
seq: Tensor,
|
|
341
|
+
scale: Tensor,
|
|
342
|
+
block_table: Tensor,
|
|
343
|
+
block_size: int,
|
|
344
|
+
partition: int,
|
|
345
|
+
is_bidirectional: bool,
|
|
346
|
+
k_scale: Tensor,
|
|
347
|
+
v_scale: Tensor,
|
|
348
|
+
mask: Optional[Tensor] = None,
|
|
349
|
+
) -> Tensor:
|
|
350
|
+
return torch.empty_like(q)
|
optimum/rbln/transformers/__init__.py
CHANGED
|
@@ -34,6 +34,7 @@ _import_structure = {
|
|
|
34
34
|
"RBLNAutoModelForSequenceClassification",
|
|
35
35
|
"RBLNAutoModelForSpeechSeq2Seq",
|
|
36
36
|
"RBLNAutoModelForVision2Seq",
|
|
37
|
+
"RBLNAutoModelForTextEncoding",
|
|
37
38
|
"RBLNBartForConditionalGeneration",
|
|
38
39
|
"RBLNBartForConditionalGenerationConfig",
|
|
39
40
|
"RBLNBartModel",
|
|
@@ -171,6 +172,7 @@ if TYPE_CHECKING:
|
|
|
171
172
|
RBLNAutoModelForSeq2SeqLM,
|
|
172
173
|
RBLNAutoModelForSequenceClassification,
|
|
173
174
|
RBLNAutoModelForSpeechSeq2Seq,
|
|
175
|
+
RBLNAutoModelForTextEncoding,
|
|
174
176
|
RBLNAutoModelForVision2Seq,
|
|
175
177
|
RBLNBartForConditionalGeneration,
|
|
176
178
|
RBLNBartForConditionalGenerationConfig,
|
optimum/rbln/transformers/configuration_generic.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, List, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
from ..configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -25,7 +25,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
|
|
|
25
25
|
max_seq_len: Optional[int] = None,
|
|
26
26
|
batch_size: Optional[int] = None,
|
|
27
27
|
model_input_names: Optional[List[str]] = None,
|
|
28
|
-
**kwargs:
|
|
28
|
+
**kwargs: Any,
|
|
29
29
|
):
|
|
30
30
|
"""
|
|
31
31
|
Args:
|
|
@@ -52,7 +52,7 @@ class RBLNImageModelConfig(RBLNModelConfig):
|
|
|
52
52
|
self,
|
|
53
53
|
image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
54
54
|
batch_size: Optional[int] = None,
|
|
55
|
-
**kwargs:
|
|
55
|
+
**kwargs: Any,
|
|
56
56
|
):
|
|
57
57
|
"""
|
|
58
58
|
Args:
|
|
@@ -124,7 +124,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
|
|
|
124
124
|
batch_size: Optional[int] = None,
|
|
125
125
|
max_length: Optional[int] = None,
|
|
126
126
|
num_mel_bins: Optional[int] = None,
|
|
127
|
-
**kwargs:
|
|
127
|
+
**kwargs: Any,
|
|
128
128
|
):
|
|
129
129
|
"""
|
|
130
130
|
Args:
|
optimum/rbln/transformers/modeling_generic.py
CHANGED
|
@@ -34,10 +34,7 @@ from transformers import (
|
|
|
34
34
|
AutoModelForTextEncoding,
|
|
35
35
|
PretrainedConfig,
|
|
36
36
|
)
|
|
37
|
-
from transformers.modeling_outputs import
|
|
38
|
-
BaseModelOutput,
|
|
39
|
-
QuestionAnsweringModelOutput,
|
|
40
|
-
)
|
|
37
|
+
from transformers.modeling_outputs import BaseModelOutput, QuestionAnsweringModelOutput
|
|
41
38
|
|
|
42
39
|
from ..configuration_utils import RBLNCompileConfig
|
|
43
40
|
from ..modeling import RBLNModel
|
optimum/rbln/transformers/modeling_outputs.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from typing import Optional, Tuple
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from transformers.modeling_outputs import ModelOutput
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
|
|
23
|
+
class RBLNDecoderOnlyOutput(ModelOutput):
|
|
24
|
+
logits: torch.FloatTensor = None
|
|
25
|
+
generate_idx: torch.Tensor = None
|
|
26
|
+
padded_cache_lengths: int = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
|
|
31
|
+
attention_mask: Optional[torch.Tensor] = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
|
|
35
|
+
class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
|
|
36
|
+
last_hidden_states: torch.FloatTensor = None
|
|
37
|
+
params: Tuple[torch.FloatTensor] = None
|
optimum/rbln/transformers/models/__init__.py
CHANGED
|
@@ -36,6 +36,7 @@ _import_structure = {
|
|
|
36
36
|
"RBLNAutoModelForSpeechSeq2Seq",
|
|
37
37
|
"RBLNAutoModelForVision2Seq",
|
|
38
38
|
"RBLNAutoModelForImageTextToText",
|
|
39
|
+
"RBLNAutoModelForTextEncoding",
|
|
39
40
|
],
|
|
40
41
|
"bart": [
|
|
41
42
|
"RBLNBartForConditionalGeneration",
|
|
@@ -162,10 +163,7 @@ _import_structure = {
|
|
|
162
163
|
}
|
|
163
164
|
|
|
164
165
|
if TYPE_CHECKING:
|
|
165
|
-
from .audio_spectrogram_transformer import
|
|
166
|
-
RBLNASTForAudioClassification,
|
|
167
|
-
RBLNASTForAudioClassificationConfig,
|
|
168
|
-
)
|
|
166
|
+
from .audio_spectrogram_transformer import RBLNASTForAudioClassification, RBLNASTForAudioClassificationConfig
|
|
169
167
|
from .auto import (
|
|
170
168
|
RBLNAutoModel,
|
|
171
169
|
RBLNAutoModelForAudioClassification,
|
|
@@ -179,6 +177,7 @@ if TYPE_CHECKING:
|
|
|
179
177
|
RBLNAutoModelForSeq2SeqLM,
|
|
180
178
|
RBLNAutoModelForSequenceClassification,
|
|
181
179
|
RBLNAutoModelForSpeechSeq2Seq,
|
|
180
|
+
RBLNAutoModelForTextEncoding,
|
|
182
181
|
RBLNAutoModelForVision2Seq,
|
|
183
182
|
)
|
|
184
183
|
from .bart import (
|
|
@@ -213,24 +212,15 @@ if TYPE_CHECKING:
|
|
|
213
212
|
RBLNCLIPVisionModelWithProjection,
|
|
214
213
|
RBLNCLIPVisionModelWithProjectionConfig,
|
|
215
214
|
)
|
|
216
|
-
from .colpali import
|
|
217
|
-
RBLNColPaliForRetrieval,
|
|
218
|
-
RBLNColPaliForRetrievalConfig,
|
|
219
|
-
)
|
|
215
|
+
from .colpali import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
|
|
220
216
|
from .decoderonly import (
|
|
221
217
|
RBLNDecoderOnlyModel,
|
|
222
218
|
RBLNDecoderOnlyModelConfig,
|
|
223
219
|
RBLNDecoderOnlyModelForCausalLM,
|
|
224
220
|
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
225
221
|
)
|
|
226
|
-
from .distilbert import
|
|
227
|
-
|
|
228
|
-
RBLNDistilBertForQuestionAnsweringConfig,
|
|
229
|
-
)
|
|
230
|
-
from .dpt import (
|
|
231
|
-
RBLNDPTForDepthEstimation,
|
|
232
|
-
RBLNDPTForDepthEstimationConfig,
|
|
233
|
-
)
|
|
222
|
+
from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
|
|
223
|
+
from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
|
|
234
224
|
from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
|
|
235
225
|
from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
|
|
236
226
|
from .gemma3 import (
|
optimum/rbln/transformers/models/auto/modeling_auto.py
CHANGED
|
@@ -35,6 +35,8 @@ from transformers.models.auto.modeling_auto import (
|
|
|
35
35
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
|
36
36
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
|
37
37
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
|
38
|
+
MODEL_FOR_TEXT_ENCODING_MAPPING,
|
|
39
|
+
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
|
|
38
40
|
MODEL_FOR_VISION_2_SEQ_MAPPING,
|
|
39
41
|
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
|
|
40
42
|
MODEL_MAPPING,
|
|
@@ -115,3 +117,8 @@ class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
|
|
|
115
117
|
class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
|
|
116
118
|
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
|
117
119
|
_model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
|
|
123
|
+
_model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
|
124
|
+
_model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
|
optimum/rbln/transformers/models/bart/bart_architecture.py
CHANGED
|
@@ -16,9 +16,7 @@ from typing import Tuple
|
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from torch import nn
|
|
19
|
-
from transformers.modeling_attn_mask_utils import
|
|
20
|
-
_prepare_4d_attention_mask,
|
|
21
|
-
)
|
|
19
|
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|
22
20
|
from transformers.utils import logging
|
|
23
21
|
|
|
24
22
|
from ..seq2seq.seq2seq_architecture import (
|
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -62,7 +62,7 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
|
|
|
62
62
|
vision_model: Optional[RBLNModelConfig] = None,
|
|
63
63
|
qformer: Optional[RBLNModelConfig] = None,
|
|
64
64
|
language_model: Optional[RBLNModelConfig] = None,
|
|
65
|
-
**kwargs:
|
|
65
|
+
**kwargs: Any,
|
|
66
66
|
):
|
|
67
67
|
"""
|
|
68
68
|
Args:
|
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
CHANGED
|
@@ -35,11 +35,7 @@ from ....modeling import RBLNModel
|
|
|
35
35
|
logger = logging.get_logger(__name__)
|
|
36
36
|
|
|
37
37
|
if TYPE_CHECKING:
|
|
38
|
-
from transformers import
|
|
39
|
-
AutoFeatureExtractor,
|
|
40
|
-
AutoProcessor,
|
|
41
|
-
AutoTokenizer,
|
|
42
|
-
)
|
|
38
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
|
43
39
|
|
|
44
40
|
|
|
45
41
|
class LoopProjector:
|
optimum/rbln/transformers/models/clip/configuration_clip.py
CHANGED
|
@@ -12,13 +12,13 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Optional
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class RBLNCLIPTextModelConfig(RBLNModelConfig):
|
|
21
|
-
def __init__(self, batch_size: Optional[int] = None, **kwargs:
|
|
21
|
+
def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
|
|
22
22
|
"""
|
|
23
23
|
Args:
|
|
24
24
|
batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
|
|
@@ -50,7 +50,7 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
|
|
|
50
50
|
interpolate_pos_encoding: Optional[bool] = None,
|
|
51
51
|
output_hidden_states: Optional[bool] = None,
|
|
52
52
|
output_attentions: Optional[bool] = None,
|
|
53
|
-
**kwargs:
|
|
53
|
+
**kwargs: Any,
|
|
54
54
|
):
|
|
55
55
|
"""
|
|
56
56
|
Args:
|
optimum/rbln/transformers/models/colpali/colpali_architecture.py
CHANGED
|
@@ -4,10 +4,7 @@ import torch
|
|
|
4
4
|
from torch import nn
|
|
5
5
|
from transformers import GemmaForCausalLM, GemmaModel
|
|
6
6
|
|
|
7
|
-
from ..decoderonly.decoderonly_architecture import
|
|
8
|
-
RotaryEmbedding,
|
|
9
|
-
apply_rotary_pos_emb,
|
|
10
|
-
)
|
|
7
|
+
from ..decoderonly.decoderonly_architecture import RotaryEmbedding, apply_rotary_pos_emb
|
|
11
8
|
|
|
12
9
|
|
|
13
10
|
def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
|
optimum/rbln/transformers/models/colpali/configuration_colpali.py
CHANGED
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
from typing import Any,
|
|
14
|
+
from typing import Any, List, Optional, Union
|
|
15
15
|
|
|
16
16
|
from ....configuration_utils import RBLNModelConfig
|
|
17
17
|
|
|
@@ -50,7 +50,7 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
|
|
|
50
50
|
max_seq_lens: Union[int, List[int]] = None,
|
|
51
51
|
output_hidden_states: Optional[bool] = None,
|
|
52
52
|
vision_tower: Optional[RBLNModelConfig] = None,
|
|
53
|
-
**kwargs:
|
|
53
|
+
**kwargs: Any,
|
|
54
54
|
):
|
|
55
55
|
"""
|
|
56
56
|
Args:
|