optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/attn.py
CHANGED
|
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
|
|
|
53
53
|
return torch.empty_like(q)
|
|
54
54
|
|
|
55
55
|
|
|
56
|
+
@torch.library.custom_op(
|
|
57
|
+
"rbln_custom_ops::paged_attn_decode_kv_fp8",
|
|
58
|
+
mutates_args=(["kcache", "vcache"]),
|
|
59
|
+
)
|
|
60
|
+
def paged_attn_decode_kv_fp8(
|
|
61
|
+
q: Tensor,
|
|
62
|
+
k: Tensor,
|
|
63
|
+
v: Tensor,
|
|
64
|
+
mask: Tensor,
|
|
65
|
+
kcache: Tensor,
|
|
66
|
+
vcache: Tensor,
|
|
67
|
+
seq: Tensor,
|
|
68
|
+
scale: Tensor,
|
|
69
|
+
block_table: Tensor,
|
|
70
|
+
block_size: int,
|
|
71
|
+
k_scale: Tensor,
|
|
72
|
+
v_scale: Tensor,
|
|
73
|
+
) -> Tensor:
|
|
74
|
+
return torch.empty_like(q)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@paged_attn_decode_kv_fp8.register_fake
|
|
78
|
+
def paged_attn_decode_kv_fp8_fake(
|
|
79
|
+
q: Tensor,
|
|
80
|
+
k: Tensor,
|
|
81
|
+
v: Tensor,
|
|
82
|
+
mask: Tensor,
|
|
83
|
+
kcache: Tensor,
|
|
84
|
+
vcache: Tensor,
|
|
85
|
+
seq: Tensor,
|
|
86
|
+
scale: Tensor,
|
|
87
|
+
block_table: Tensor,
|
|
88
|
+
block_size: int,
|
|
89
|
+
k_scale: Tensor,
|
|
90
|
+
v_scale: Tensor,
|
|
91
|
+
) -> Tensor:
|
|
92
|
+
return torch.empty_like(q)
|
|
93
|
+
|
|
94
|
+
|
|
56
95
|
@torch.library.custom_op(
|
|
57
96
|
"rbln_custom_ops::paged_attn_prefill",
|
|
58
97
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
|
|
|
112
151
|
return torch.empty_like(q)
|
|
113
152
|
|
|
114
153
|
|
|
154
|
+
@torch.library.custom_op(
|
|
155
|
+
"rbln_custom_ops::paged_attn_prefill_kv_fp8",
|
|
156
|
+
mutates_args=(["kcache", "vcache"]),
|
|
157
|
+
)
|
|
158
|
+
def paged_attn_prefill_kv_fp8(
|
|
159
|
+
q: Tensor,
|
|
160
|
+
k: Tensor,
|
|
161
|
+
v: Tensor,
|
|
162
|
+
mask: Tensor,
|
|
163
|
+
kcache: Tensor,
|
|
164
|
+
vcache: Tensor,
|
|
165
|
+
seq: Tensor,
|
|
166
|
+
scale: Tensor,
|
|
167
|
+
block_table: Tensor,
|
|
168
|
+
block_size: int,
|
|
169
|
+
k_scale: Tensor,
|
|
170
|
+
v_scale: Tensor,
|
|
171
|
+
) -> Tensor:
|
|
172
|
+
return torch.empty_like(q)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@paged_attn_prefill_kv_fp8.register_fake
|
|
176
|
+
def paged_attn_prefill_kv_fp8_fake(
|
|
177
|
+
q: Tensor,
|
|
178
|
+
k: Tensor,
|
|
179
|
+
v: Tensor,
|
|
180
|
+
mask: Tensor,
|
|
181
|
+
kcache: Tensor,
|
|
182
|
+
vcache: Tensor,
|
|
183
|
+
seq: Tensor,
|
|
184
|
+
scale: Tensor,
|
|
185
|
+
block_table: Tensor,
|
|
186
|
+
block_size: int,
|
|
187
|
+
k_scale: Tensor,
|
|
188
|
+
v_scale: Tensor,
|
|
189
|
+
) -> Tensor:
|
|
190
|
+
return torch.empty_like(q)
|
|
191
|
+
|
|
192
|
+
|
|
115
193
|
@torch.library.custom_op(
|
|
116
194
|
"rbln_custom_ops::paged_causal_attn_decode",
|
|
117
195
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
|
|
|
236
314
|
return torch.empty_like(q)
|
|
237
315
|
|
|
238
316
|
|
|
317
|
+
@torch.library.custom_op(
|
|
318
|
+
"rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
|
|
319
|
+
mutates_args=(["kcache", "vcache"]),
|
|
320
|
+
)
|
|
321
|
+
def paged_causal_attn_decode_kv_fp8(
|
|
322
|
+
q: Tensor,
|
|
323
|
+
k: Tensor,
|
|
324
|
+
v: Tensor,
|
|
325
|
+
kcache: Tensor,
|
|
326
|
+
vcache: Tensor,
|
|
327
|
+
seq: Tensor,
|
|
328
|
+
scale: Tensor,
|
|
329
|
+
block_table: Tensor,
|
|
330
|
+
block_size: int,
|
|
331
|
+
k_scale: Tensor,
|
|
332
|
+
v_scale: Tensor,
|
|
333
|
+
mask: Optional[Tensor] = None,
|
|
334
|
+
) -> Tensor:
|
|
335
|
+
return torch.empty_like(q)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
@paged_causal_attn_decode_kv_fp8.register_fake
|
|
339
|
+
def paged_causal_attn_decode_kv_fp8_fake(
|
|
340
|
+
q: Tensor,
|
|
341
|
+
k: Tensor,
|
|
342
|
+
v: Tensor,
|
|
343
|
+
kcache: Tensor,
|
|
344
|
+
vcache: Tensor,
|
|
345
|
+
seq: Tensor,
|
|
346
|
+
scale: Tensor,
|
|
347
|
+
block_table: Tensor,
|
|
348
|
+
block_size: int,
|
|
349
|
+
k_scale: Tensor,
|
|
350
|
+
v_scale: Tensor,
|
|
351
|
+
mask: Optional[Tensor] = None,
|
|
352
|
+
) -> Tensor:
|
|
353
|
+
return torch.empty_like(q)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@torch.library.custom_op(
|
|
357
|
+
"rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
|
|
358
|
+
mutates_args=(["kcache", "vcache"]),
|
|
359
|
+
)
|
|
360
|
+
def paged_causal_attn_prefill_kv_fp8(
|
|
361
|
+
q: Tensor,
|
|
362
|
+
k: Tensor,
|
|
363
|
+
v: Tensor,
|
|
364
|
+
kcache: Tensor,
|
|
365
|
+
vcache: Tensor,
|
|
366
|
+
seq: Tensor,
|
|
367
|
+
scale: Tensor,
|
|
368
|
+
block_table: Tensor,
|
|
369
|
+
block_size: int,
|
|
370
|
+
is_bidirectional: bool,
|
|
371
|
+
k_scale: Tensor,
|
|
372
|
+
v_scale: Tensor,
|
|
373
|
+
mask: Optional[Tensor] = None,
|
|
374
|
+
) -> Tensor:
|
|
375
|
+
return torch.empty_like(q)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@paged_causal_attn_prefill_kv_fp8.register_fake
|
|
379
|
+
def paged_causal_attn_prefill_kv_fp8_fake(
|
|
380
|
+
q: Tensor,
|
|
381
|
+
k: Tensor,
|
|
382
|
+
v: Tensor,
|
|
383
|
+
kcache: Tensor,
|
|
384
|
+
vcache: Tensor,
|
|
385
|
+
seq: Tensor,
|
|
386
|
+
scale: Tensor,
|
|
387
|
+
block_table: Tensor,
|
|
388
|
+
block_size: int,
|
|
389
|
+
is_bidirectional: bool,
|
|
390
|
+
k_scale: Tensor,
|
|
391
|
+
v_scale: Tensor,
|
|
392
|
+
mask: Optional[Tensor] = None,
|
|
393
|
+
) -> Tensor:
|
|
394
|
+
return torch.empty_like(q)
|
|
395
|
+
|
|
396
|
+
|
|
239
397
|
@torch.library.custom_op(
|
|
240
398
|
"rbln_custom_ops::paged_add_softmax_attn_decode",
|
|
241
399
|
mutates_args=(["kcache", "vcache"]),
|
optimum/rbln/ops/flash_attn.py
CHANGED
|
@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
|
|
|
59
59
|
return torch.empty_like(q)
|
|
60
60
|
|
|
61
61
|
|
|
62
|
+
@torch.library.custom_op(
|
|
63
|
+
"rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
|
|
64
|
+
mutates_args=(["kcache", "vcache"]),
|
|
65
|
+
)
|
|
66
|
+
def paged_flash_attn_decode_kv_fp8(
|
|
67
|
+
q: Tensor,
|
|
68
|
+
k: Tensor,
|
|
69
|
+
v: Tensor,
|
|
70
|
+
mask: Tensor,
|
|
71
|
+
kcache: Tensor,
|
|
72
|
+
vcache: Tensor,
|
|
73
|
+
seq: Tensor,
|
|
74
|
+
scale: Tensor,
|
|
75
|
+
block_table: Tensor,
|
|
76
|
+
block_size: int,
|
|
77
|
+
partition: int,
|
|
78
|
+
k_scale: Tensor,
|
|
79
|
+
v_scale: Tensor,
|
|
80
|
+
) -> Tensor:
|
|
81
|
+
return torch.empty_like(q)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@paged_flash_attn_decode_kv_fp8.register_fake
|
|
85
|
+
def paged_flash_attn_decode_kv_fp8_fake(
|
|
86
|
+
q: Tensor,
|
|
87
|
+
k: Tensor,
|
|
88
|
+
v: Tensor,
|
|
89
|
+
mask: Tensor,
|
|
90
|
+
kcache: Tensor,
|
|
91
|
+
vcache: Tensor,
|
|
92
|
+
seq: Tensor,
|
|
93
|
+
scale: Tensor,
|
|
94
|
+
block_table: Tensor,
|
|
95
|
+
block_size: int,
|
|
96
|
+
partition: int,
|
|
97
|
+
k_scale: Tensor,
|
|
98
|
+
v_scale: Tensor,
|
|
99
|
+
) -> Tensor:
|
|
100
|
+
return torch.empty_like(q)
|
|
101
|
+
|
|
102
|
+
|
|
62
103
|
@torch.library.custom_op(
|
|
63
104
|
"rbln_custom_ops::paged_flash_attn_prefill",
|
|
64
105
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
|
|
|
100
141
|
return torch.empty_like(q)
|
|
101
142
|
|
|
102
143
|
|
|
144
|
+
@torch.library.custom_op(
|
|
145
|
+
"rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
|
|
146
|
+
mutates_args=(["kcache", "vcache"]),
|
|
147
|
+
)
|
|
148
|
+
def paged_flash_attn_prefill_kv_fp8(
|
|
149
|
+
q: Tensor,
|
|
150
|
+
k: Tensor,
|
|
151
|
+
v: Tensor,
|
|
152
|
+
mask: Tensor,
|
|
153
|
+
kcache: Tensor,
|
|
154
|
+
vcache: Tensor,
|
|
155
|
+
seq: Tensor,
|
|
156
|
+
scale: Tensor,
|
|
157
|
+
block_table: Tensor,
|
|
158
|
+
block_size: int,
|
|
159
|
+
partition: int,
|
|
160
|
+
k_scale: Tensor,
|
|
161
|
+
v_scale: Tensor,
|
|
162
|
+
) -> Tensor:
|
|
163
|
+
return torch.empty_like(q)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@paged_flash_attn_prefill_kv_fp8.register_fake
|
|
167
|
+
def paged_flash_attn_prefill_kv_fp8_fake(
|
|
168
|
+
q: Tensor,
|
|
169
|
+
k: Tensor,
|
|
170
|
+
v: Tensor,
|
|
171
|
+
mask: Tensor,
|
|
172
|
+
kcache: Tensor,
|
|
173
|
+
vcache: Tensor,
|
|
174
|
+
seq: Tensor,
|
|
175
|
+
scale: Tensor,
|
|
176
|
+
block_table: Tensor,
|
|
177
|
+
block_size: int,
|
|
178
|
+
partition: int,
|
|
179
|
+
k_scale: Tensor,
|
|
180
|
+
v_scale: Tensor,
|
|
181
|
+
) -> Tensor:
|
|
182
|
+
return torch.empty_like(q)
|
|
183
|
+
|
|
184
|
+
|
|
103
185
|
@torch.library.custom_op(
|
|
104
186
|
"rbln_custom_ops::paged_flash_causal_attn_decode",
|
|
105
187
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
|
|
|
141
223
|
return torch.empty_like(q)
|
|
142
224
|
|
|
143
225
|
|
|
226
|
+
@torch.library.custom_op(
|
|
227
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
|
|
228
|
+
mutates_args=(["kcache", "vcache"]),
|
|
229
|
+
)
|
|
230
|
+
def paged_flash_causal_attn_decode_kv_fp8(
|
|
231
|
+
q: Tensor,
|
|
232
|
+
k: Tensor,
|
|
233
|
+
v: Tensor,
|
|
234
|
+
kcache: Tensor,
|
|
235
|
+
vcache: Tensor,
|
|
236
|
+
seq: Tensor,
|
|
237
|
+
scale: Tensor,
|
|
238
|
+
block_table: Tensor,
|
|
239
|
+
block_size: int,
|
|
240
|
+
partition: int,
|
|
241
|
+
k_scale: Tensor,
|
|
242
|
+
v_scale: Tensor,
|
|
243
|
+
mask: Optional[Tensor] = None,
|
|
244
|
+
) -> Tensor:
|
|
245
|
+
return torch.empty_like(q)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@paged_flash_causal_attn_decode_kv_fp8.register_fake
|
|
249
|
+
def paged_flash_causal_attn_decode_kv_fp8_fake(
|
|
250
|
+
q: Tensor,
|
|
251
|
+
k: Tensor,
|
|
252
|
+
v: Tensor,
|
|
253
|
+
kcache: Tensor,
|
|
254
|
+
vcache: Tensor,
|
|
255
|
+
seq: Tensor,
|
|
256
|
+
scale: Tensor,
|
|
257
|
+
block_table: Tensor,
|
|
258
|
+
block_size: int,
|
|
259
|
+
partition: int,
|
|
260
|
+
k_scale: Tensor,
|
|
261
|
+
v_scale: Tensor,
|
|
262
|
+
mask: Optional[Tensor] = None,
|
|
263
|
+
) -> Tensor:
|
|
264
|
+
return torch.empty_like(q)
|
|
265
|
+
|
|
266
|
+
|
|
144
267
|
@torch.library.custom_op(
|
|
145
268
|
"rbln_custom_ops::paged_flash_causal_attn_prefill",
|
|
146
269
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
|
|
|
182
305
|
mask: Optional[Tensor] = None,
|
|
183
306
|
) -> Tensor:
|
|
184
307
|
return torch.empty_like(q)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@torch.library.custom_op(
|
|
311
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
|
|
312
|
+
mutates_args=(["kcache", "vcache"]),
|
|
313
|
+
)
|
|
314
|
+
def paged_flash_causal_attn_prefill_kv_fp8(
|
|
315
|
+
q: Tensor,
|
|
316
|
+
k: Tensor,
|
|
317
|
+
v: Tensor,
|
|
318
|
+
kcache: Tensor,
|
|
319
|
+
vcache: Tensor,
|
|
320
|
+
seq: Tensor,
|
|
321
|
+
scale: Tensor,
|
|
322
|
+
block_table: Tensor,
|
|
323
|
+
block_size: int,
|
|
324
|
+
partition: int,
|
|
325
|
+
is_bidirectional: bool,
|
|
326
|
+
k_scale: Tensor,
|
|
327
|
+
v_scale: Tensor,
|
|
328
|
+
mask: Optional[Tensor] = None,
|
|
329
|
+
) -> Tensor:
|
|
330
|
+
return torch.empty_like(q)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@paged_flash_causal_attn_prefill_kv_fp8.register_fake
|
|
334
|
+
def paged_flash_causal_attn_prefill_kv_fp8_fake(
|
|
335
|
+
q: Tensor,
|
|
336
|
+
k: Tensor,
|
|
337
|
+
v: Tensor,
|
|
338
|
+
kcache: Tensor,
|
|
339
|
+
vcache: Tensor,
|
|
340
|
+
seq: Tensor,
|
|
341
|
+
scale: Tensor,
|
|
342
|
+
block_table: Tensor,
|
|
343
|
+
block_size: int,
|
|
344
|
+
partition: int,
|
|
345
|
+
is_bidirectional: bool,
|
|
346
|
+
k_scale: Tensor,
|
|
347
|
+
v_scale: Tensor,
|
|
348
|
+
mask: Optional[Tensor] = None,
|
|
349
|
+
) -> Tensor:
|
|
350
|
+
return torch.empty_like(q)
|
|
@@ -22,3 +22,8 @@ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tens
|
|
|
22
22
|
# This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
|
|
23
23
|
# The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
|
|
24
24
|
return torch.empty_like(cache)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@rbln_cache_update.register_fake
|
|
28
|
+
def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
|
|
29
|
+
return torch.empty_like(cache)
|
optimum/rbln/ops/linear.py
CHANGED
|
@@ -23,3 +23,10 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens
|
|
|
23
23
|
output_shape = list(input.shape[:-1])
|
|
24
24
|
output_shape += [weight.shape[0]]
|
|
25
25
|
return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@linear.register_fake
|
|
29
|
+
def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
|
|
30
|
+
output_shape = list(input.shape[:-1])
|
|
31
|
+
output_shape += [weight.shape[0]]
|
|
32
|
+
return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
|
|
@@ -34,6 +34,8 @@ _import_structure = {
|
|
|
34
34
|
"RBLNAutoModelForSequenceClassification",
|
|
35
35
|
"RBLNAutoModelForSpeechSeq2Seq",
|
|
36
36
|
"RBLNAutoModelForVision2Seq",
|
|
37
|
+
"RBLNAutoModelForTextEncoding",
|
|
38
|
+
"RBLNAutoModelForZeroShotObjectDetection",
|
|
37
39
|
"RBLNBartForConditionalGeneration",
|
|
38
40
|
"RBLNBartForConditionalGenerationConfig",
|
|
39
41
|
"RBLNBartModel",
|
|
@@ -52,6 +54,8 @@ _import_structure = {
|
|
|
52
54
|
"RBLNBlip2VisionModelConfig",
|
|
53
55
|
"RBLNColPaliForRetrieval",
|
|
54
56
|
"RBLNColPaliForRetrievalConfig",
|
|
57
|
+
"RBLNColQwen2ForRetrieval",
|
|
58
|
+
"RBLNColQwen2ForRetrievalConfig",
|
|
55
59
|
"RBLNCLIPTextModel",
|
|
56
60
|
"RBLNCLIPTextModelConfig",
|
|
57
61
|
"RBLNCLIPTextModelWithProjection",
|
|
@@ -62,12 +66,18 @@ _import_structure = {
|
|
|
62
66
|
"RBLNCLIPVisionModelWithProjectionConfig",
|
|
63
67
|
"RBLNDecoderOnlyModelForCausalLM",
|
|
64
68
|
"RBLNDecoderOnlyModelForCausalLMConfig",
|
|
69
|
+
"RBLNDecoderOnlyModelConfig",
|
|
70
|
+
"RBLNDecoderOnlyModel",
|
|
65
71
|
"RBLNDistilBertForQuestionAnswering",
|
|
66
72
|
"RBLNDistilBertForQuestionAnsweringConfig",
|
|
67
73
|
"RBLNDPTForDepthEstimation",
|
|
68
74
|
"RBLNDPTForDepthEstimationConfig",
|
|
75
|
+
"RBLNDepthAnythingForDepthEstimation",
|
|
76
|
+
"RBLNDepthAnythingForDepthEstimationConfig",
|
|
69
77
|
"RBLNExaoneForCausalLM",
|
|
70
78
|
"RBLNExaoneForCausalLMConfig",
|
|
79
|
+
"RBLNGemmaModel",
|
|
80
|
+
"RBLNGemmaModelConfig",
|
|
71
81
|
"RBLNGemma3ForCausalLM",
|
|
72
82
|
"RBLNGemma3ForCausalLMConfig",
|
|
73
83
|
"RBLNGemma3ForConditionalGeneration",
|
|
@@ -76,26 +86,60 @@ _import_structure = {
|
|
|
76
86
|
"RBLNGemmaForCausalLMConfig",
|
|
77
87
|
"RBLNGPT2LMHeadModel",
|
|
78
88
|
"RBLNGPT2LMHeadModelConfig",
|
|
89
|
+
"RBLNGPT2Model",
|
|
90
|
+
"RBLNGPT2ModelConfig",
|
|
91
|
+
"RBLNGroundingDinoDecoder",
|
|
92
|
+
"RBLNGroundingDinoDecoderConfig",
|
|
93
|
+
"RBLNGroundingDinoForObjectDetection",
|
|
94
|
+
"RBLNGroundingDinoForObjectDetectionConfig",
|
|
95
|
+
"RBLNGroundingDinoEncoder",
|
|
96
|
+
"RBLNGroundingDinoEncoderConfig",
|
|
79
97
|
"RBLNIdefics3ForConditionalGeneration",
|
|
80
98
|
"RBLNIdefics3ForConditionalGenerationConfig",
|
|
81
99
|
"RBLNIdefics3VisionTransformer",
|
|
82
100
|
"RBLNIdefics3VisionTransformerConfig",
|
|
83
101
|
"RBLNLlamaForCausalLM",
|
|
84
102
|
"RBLNLlamaForCausalLMConfig",
|
|
103
|
+
"RBLNLlavaForConditionalGeneration",
|
|
104
|
+
"RBLNLlavaForConditionalGenerationConfig",
|
|
105
|
+
"RBLNLlamaModel",
|
|
106
|
+
"RBLNLlamaModelConfig",
|
|
107
|
+
"RBLNOPTForCausalLM",
|
|
108
|
+
"RBLNOPTForCausalLMConfig",
|
|
109
|
+
"RBLNPegasusForConditionalGeneration",
|
|
110
|
+
"RBLNPegasusForConditionalGenerationConfig",
|
|
111
|
+
"RBLNPegasusModel",
|
|
112
|
+
"RBLNPegasusModelConfig",
|
|
85
113
|
"RBLNLlavaNextForConditionalGeneration",
|
|
86
114
|
"RBLNLlavaNextForConditionalGenerationConfig",
|
|
115
|
+
"RBLNLoRAAdapterConfig",
|
|
116
|
+
"RBLNLoRAConfig",
|
|
87
117
|
"RBLNMidmLMHeadModel",
|
|
88
118
|
"RBLNMidmLMHeadModelConfig",
|
|
89
119
|
"RBLNMistralForCausalLM",
|
|
90
120
|
"RBLNMistralForCausalLMConfig",
|
|
121
|
+
"RBLNMistralModel",
|
|
122
|
+
"RBLNMistralModelConfig",
|
|
91
123
|
"RBLNOPTForCausalLM",
|
|
92
124
|
"RBLNOPTForCausalLMConfig",
|
|
125
|
+
"RBLNOPTModel",
|
|
126
|
+
"RBLNOPTModelConfig",
|
|
93
127
|
"RBLNPhiForCausalLM",
|
|
94
128
|
"RBLNPhiForCausalLMConfig",
|
|
129
|
+
"RBLNPixtralVisionModelConfig",
|
|
130
|
+
"RBLNPixtralVisionModel",
|
|
131
|
+
"RBLNPhiModel",
|
|
132
|
+
"RBLNPhiModelConfig",
|
|
95
133
|
"RBLNQwen2_5_VisionTransformerPretrainedModel",
|
|
96
134
|
"RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
|
|
97
135
|
"RBLNQwen2_5_VLForConditionalGeneration",
|
|
98
136
|
"RBLNQwen2_5_VLForConditionalGenerationConfig",
|
|
137
|
+
"RBLNQwen2VisionTransformerPretrainedModel",
|
|
138
|
+
"RBLNQwen2VisionTransformerPretrainedModelConfig",
|
|
139
|
+
"RBLNQwen2VLForConditionalGeneration",
|
|
140
|
+
"RBLNQwen2VLForConditionalGenerationConfig",
|
|
141
|
+
"RBLNQwen2Model",
|
|
142
|
+
"RBLNQwen2ModelConfig",
|
|
99
143
|
"RBLNQwen2ForCausalLM",
|
|
100
144
|
"RBLNQwen2ForCausalLMConfig",
|
|
101
145
|
"RBLNQwen3ForCausalLM",
|
|
@@ -110,6 +154,8 @@ _import_structure = {
|
|
|
110
154
|
"RBLNRobertaForSequenceClassificationConfig",
|
|
111
155
|
"RBLNSiglipVisionModel",
|
|
112
156
|
"RBLNSiglipVisionModelConfig",
|
|
157
|
+
"RBLNSwinBackbone",
|
|
158
|
+
"RBLNSwinBackboneConfig",
|
|
113
159
|
"RBLNT5EncoderModel",
|
|
114
160
|
"RBLNT5EncoderModelConfig",
|
|
115
161
|
"RBLNT5ForConditionalGeneration",
|
|
@@ -145,7 +191,9 @@ if TYPE_CHECKING:
|
|
|
145
191
|
RBLNAutoModelForSeq2SeqLM,
|
|
146
192
|
RBLNAutoModelForSequenceClassification,
|
|
147
193
|
RBLNAutoModelForSpeechSeq2Seq,
|
|
194
|
+
RBLNAutoModelForTextEncoding,
|
|
148
195
|
RBLNAutoModelForVision2Seq,
|
|
196
|
+
RBLNAutoModelForZeroShotObjectDetection,
|
|
149
197
|
RBLNBartForConditionalGeneration,
|
|
150
198
|
RBLNBartForConditionalGenerationConfig,
|
|
151
199
|
RBLNBartModel,
|
|
@@ -170,8 +218,16 @@ if TYPE_CHECKING:
|
|
|
170
218
|
RBLNCLIPVisionModelConfig,
|
|
171
219
|
RBLNCLIPVisionModelWithProjection,
|
|
172
220
|
RBLNCLIPVisionModelWithProjectionConfig,
|
|
221
|
+
RBLNColPaliForRetrieval,
|
|
222
|
+
RBLNColPaliForRetrievalConfig,
|
|
223
|
+
RBLNColQwen2ForRetrieval,
|
|
224
|
+
RBLNColQwen2ForRetrievalConfig,
|
|
225
|
+
RBLNDecoderOnlyModel,
|
|
226
|
+
RBLNDecoderOnlyModelConfig,
|
|
173
227
|
RBLNDecoderOnlyModelForCausalLM,
|
|
174
228
|
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
229
|
+
RBLNDepthAnythingForDepthEstimation,
|
|
230
|
+
RBLNDepthAnythingForDepthEstimationConfig,
|
|
175
231
|
RBLNDistilBertForQuestionAnswering,
|
|
176
232
|
RBLNDistilBertForQuestionAnsweringConfig,
|
|
177
233
|
RBLNDPTForDepthEstimation,
|
|
@@ -184,30 +240,64 @@ if TYPE_CHECKING:
|
|
|
184
240
|
RBLNGemma3ForConditionalGenerationConfig,
|
|
185
241
|
RBLNGemmaForCausalLM,
|
|
186
242
|
RBLNGemmaForCausalLMConfig,
|
|
243
|
+
RBLNGemmaModel,
|
|
244
|
+
RBLNGemmaModelConfig,
|
|
187
245
|
RBLNGPT2LMHeadModel,
|
|
188
246
|
RBLNGPT2LMHeadModelConfig,
|
|
247
|
+
RBLNGPT2Model,
|
|
248
|
+
RBLNGPT2ModelConfig,
|
|
249
|
+
RBLNGroundingDinoDecoder,
|
|
250
|
+
RBLNGroundingDinoDecoderConfig,
|
|
251
|
+
RBLNGroundingDinoEncoder,
|
|
252
|
+
RBLNGroundingDinoEncoderConfig,
|
|
253
|
+
RBLNGroundingDinoForObjectDetection,
|
|
254
|
+
RBLNGroundingDinoForObjectDetectionConfig,
|
|
189
255
|
RBLNIdefics3ForConditionalGeneration,
|
|
190
256
|
RBLNIdefics3ForConditionalGenerationConfig,
|
|
191
257
|
RBLNIdefics3VisionTransformer,
|
|
192
258
|
RBLNIdefics3VisionTransformerConfig,
|
|
193
259
|
RBLNLlamaForCausalLM,
|
|
194
260
|
RBLNLlamaForCausalLMConfig,
|
|
261
|
+
RBLNLlamaModel,
|
|
262
|
+
RBLNLlamaModelConfig,
|
|
263
|
+
RBLNLlavaForConditionalGeneration,
|
|
264
|
+
RBLNLlavaForConditionalGenerationConfig,
|
|
195
265
|
RBLNLlavaNextForConditionalGeneration,
|
|
196
266
|
RBLNLlavaNextForConditionalGenerationConfig,
|
|
267
|
+
RBLNLoRAAdapterConfig,
|
|
268
|
+
RBLNLoRAConfig,
|
|
197
269
|
RBLNMidmLMHeadModel,
|
|
198
270
|
RBLNMidmLMHeadModelConfig,
|
|
199
271
|
RBLNMistralForCausalLM,
|
|
200
272
|
RBLNMistralForCausalLMConfig,
|
|
273
|
+
RBLNMistralModel,
|
|
274
|
+
RBLNMistralModelConfig,
|
|
201
275
|
RBLNOPTForCausalLM,
|
|
202
276
|
RBLNOPTForCausalLMConfig,
|
|
277
|
+
RBLNOPTModel,
|
|
278
|
+
RBLNOPTModelConfig,
|
|
279
|
+
RBLNPegasusForConditionalGeneration,
|
|
280
|
+
RBLNPegasusForConditionalGenerationConfig,
|
|
281
|
+
RBLNPegasusModel,
|
|
282
|
+
RBLNPegasusModelConfig,
|
|
203
283
|
RBLNPhiForCausalLM,
|
|
204
284
|
RBLNPhiForCausalLMConfig,
|
|
285
|
+
RBLNPhiModel,
|
|
286
|
+
RBLNPhiModelConfig,
|
|
287
|
+
RBLNPixtralVisionModel,
|
|
288
|
+
RBLNPixtralVisionModelConfig,
|
|
205
289
|
RBLNQwen2_5_VisionTransformerPretrainedModel,
|
|
206
290
|
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
|
|
207
291
|
RBLNQwen2_5_VLForConditionalGeneration,
|
|
208
292
|
RBLNQwen2_5_VLForConditionalGenerationConfig,
|
|
209
293
|
RBLNQwen2ForCausalLM,
|
|
210
294
|
RBLNQwen2ForCausalLMConfig,
|
|
295
|
+
RBLNQwen2Model,
|
|
296
|
+
RBLNQwen2ModelConfig,
|
|
297
|
+
RBLNQwen2VisionTransformerPretrainedModel,
|
|
298
|
+
RBLNQwen2VisionTransformerPretrainedModelConfig,
|
|
299
|
+
RBLNQwen2VLForConditionalGeneration,
|
|
300
|
+
RBLNQwen2VLForConditionalGenerationConfig,
|
|
211
301
|
RBLNQwen3ForCausalLM,
|
|
212
302
|
RBLNQwen3ForCausalLMConfig,
|
|
213
303
|
RBLNQwen3Model,
|
|
@@ -220,6 +310,8 @@ if TYPE_CHECKING:
|
|
|
220
310
|
RBLNRobertaForSequenceClassificationConfig,
|
|
221
311
|
RBLNSiglipVisionModel,
|
|
222
312
|
RBLNSiglipVisionModelConfig,
|
|
313
|
+
RBLNSwinBackbone,
|
|
314
|
+
RBLNSwinBackboneConfig,
|
|
223
315
|
RBLNT5EncoderModel,
|
|
224
316
|
RBLNT5EncoderModelConfig,
|
|
225
317
|
RBLNT5ForConditionalGeneration,
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, List, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
from ..configuration_utils import RBLNModelConfig
|
|
18
18
|
|
|
@@ -25,7 +25,8 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
|
|
|
25
25
|
max_seq_len: Optional[int] = None,
|
|
26
26
|
batch_size: Optional[int] = None,
|
|
27
27
|
model_input_names: Optional[List[str]] = None,
|
|
28
|
-
|
|
28
|
+
model_input_shapes: Optional[List[Tuple[int, int]]] = None,
|
|
29
|
+
**kwargs: Any,
|
|
29
30
|
):
|
|
30
31
|
"""
|
|
31
32
|
Args:
|
|
@@ -33,7 +34,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
|
|
|
33
34
|
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
34
35
|
model_input_names (Optional[List[str]]): Names of the input tensors for the model.
|
|
35
36
|
Defaults to class-specific rbln_model_input_names if not provided.
|
|
36
|
-
|
|
37
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
37
38
|
|
|
38
39
|
Raises:
|
|
39
40
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -45,6 +46,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
|
|
|
45
46
|
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
|
46
47
|
|
|
47
48
|
self.model_input_names = model_input_names or self.rbln_model_input_names
|
|
49
|
+
self.model_input_shapes = model_input_shapes
|
|
48
50
|
|
|
49
51
|
|
|
50
52
|
class RBLNImageModelConfig(RBLNModelConfig):
|
|
@@ -52,14 +54,14 @@ class RBLNImageModelConfig(RBLNModelConfig):
|
|
|
52
54
|
self,
|
|
53
55
|
image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
54
56
|
batch_size: Optional[int] = None,
|
|
55
|
-
**kwargs:
|
|
57
|
+
**kwargs: Any,
|
|
56
58
|
):
|
|
57
59
|
"""
|
|
58
60
|
Args:
|
|
59
61
|
image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
|
|
60
62
|
Can be an integer for square images or a tuple (height, width).
|
|
61
63
|
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
62
|
-
|
|
64
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
63
65
|
|
|
64
66
|
Raises:
|
|
65
67
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -124,14 +126,14 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
|
|
|
124
126
|
batch_size: Optional[int] = None,
|
|
125
127
|
max_length: Optional[int] = None,
|
|
126
128
|
num_mel_bins: Optional[int] = None,
|
|
127
|
-
**kwargs:
|
|
129
|
+
**kwargs: Any,
|
|
128
130
|
):
|
|
129
131
|
"""
|
|
130
132
|
Args:
|
|
131
133
|
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
132
134
|
max_length (Optional[int]): Maximum length of the audio input in time dimension.
|
|
133
135
|
num_mel_bins (Optional[int]): Number of Mel frequency bins for audio processing.
|
|
134
|
-
|
|
136
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
135
137
|
|
|
136
138
|
Raises:
|
|
137
139
|
ValueError: If batch_size is not a positive integer.
|