optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +24 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
- optimum/rbln/diffusers/modeling_diffusers.py +72 -65
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
- optimum/rbln/diffusers/models/controlnet.py +14 -8
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +71 -37
- optimum/rbln/modeling_base.py +63 -109
- optimum/rbln/transformers/__init__.py +41 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +21 -22
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +54 -4
- optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/runtime_utils.py +49 -1
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
- optimum_rbln-0.8.1.dist-info/RECORD +211 -0
- optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_base.py
CHANGED
@@ -15,7 +15,7 @@
 import importlib
 import os
 import shutil
-from abc import ABC, abstractmethod
+from abc import ABC
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -24,10 +24,10 @@ import rebel
 import torch
 from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig

-from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig
+from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
 from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
 from .utils.logging import get_logger
-from .utils.runtime_utils import UnavailableRuntime
+from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
 from .utils.save_utils import maybe_load_preprocessors
 from .utils.submodule import SubModulesMixin

@@ -47,40 +47,6 @@ class RBLNBaseModelConfig(RBLNModelConfig):


 class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
-    """
-    An abstract base class for compiling, loading, and saving neural network models from the huggingface
-    transformers and diffusers libraries to run on RBLN NPU devices.
-
-    This class supports loading and saving models using the `from_pretrained` and `save_pretrained` methods,
-    similar to the huggingface libraries.
-
-    The `from_pretrained` method loads a model corresponding to the given `model_id` from a local repository
-    or the huggingface hub onto the NPU. If the model is a PyTorch model and `export=True` is passed as a
-    kwarg, it compiles the PyTorch model corresponding to the given `model_id` before loading. If `model_id`
-    is an already rbln-compiled model, it can be directly loaded onto the NPU with `export=False`.
-
-    `rbln_npu` is a kwarg required for compilation, specifying the name of the NPU to be used. If this
-    keyword is not specified, the NPU installed on the host machine is used. If no NPU is installed on the
-    host machine, an error occurs.
-
-    `rbln_device` specifies the device to be used at runtime. If not specified, device 0 is used.
-
-    `rbln_create_runtimes` indicates whether to create runtime objects. If False, the runtime does not load
-    the model onto the NPU. This option is particularly useful when you want to perform compilation only on a
-    host machine without an NPU.
-
-    `RBLNModel`, `RBLNModelFor*`, etc. are all child classes of RBLNBaseModel.
-
-    Models compiled in this way can be saved to a local repository using `save_pretrained` or uploaded to
-    the huggingface hub.
-
-    It also supports generation through `generate` (for transformers models that support generation).
-
-    RBLNBaseModel is a class for models consisting of an arbitrary number of `torch.nn.Module`s, and
-    therefore is an abstract class without explicit implementations of `forward` or `export` functions.
-    To inherit from this class, `forward`, `export`, etc. must be implemented.
-    """
-
     model_type = "rbln_model"
     auto_model_class = AutoModel
     config_class = AutoConfig
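The hunk above deletes the long class docstring, which was the one place the `rbln_`-prefixed loading kwargs (`rbln_npu`, `rbln_device`, `rbln_create_runtimes`) were described together. A minimal sketch of the compile-only flow that docstring described (the model id and NPU name are illustrative, not taken from this diff):

```python
from optimum.rbln import RBLNModel

# Compile on a host without an NPU: rbln_create_runtimes=False skips creating
# runtimes, and rbln_npu names the compilation target (illustrative value).
model = RBLNModel.from_pretrained(
    "some-org/some-pytorch-model",  # illustrative model id
    export=True,
    rbln_create_runtimes=False,
    rbln_npu="RBLN-CA12",
)
model.save_pretrained("compiled-model-dir")
```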
@@ -156,7 +122,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         subfolder: str = "",
         local_files_only: bool = False,
     ) -> str:
-
+        # Load the directory containing the compiled model files.
         model_path = Path(model_id)

         if model_path.is_dir():
@@ -372,22 +338,59 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     def prepare_rbln_config(
         cls, rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
     ) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
-        """
-        Extract rbln-config from kwargs and convert it to RBLNModelConfig.
-        """
+        # Extract rbln-config from kwargs and convert it to RBLNModelConfig.
+
         config_cls = cls.get_rbln_config_class()
         rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
         return rbln_config, kwargs

     @classmethod
-    def from_pretrained(cls, model_id: Union[str, Path], export: bool = False, **kwargs) -> "RBLNBaseModel":
+    def from_pretrained(
+        cls: Type["RBLNBaseModel"],
+        model_id: Union[str, Path],
+        export: bool = False,
+        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+        **kwargs: Dict[str, Any],
+    ) -> "RBLNBaseModel":
+        """
+        The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
+        User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
+
+        Args:
+            model_id: The model id of the pre-trained model to be loaded. It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+            export: A boolean flag to indicate whether the model should be compiled.
+            rbln_config: Configuration for RBLN model compilation and runtime. This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
+                For detailed configuration options, see the specific model's configuration class documentation.
+
+            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+
+        Returns:
+            A RBLN model instance ready for inference on RBLN NPU devices.
+        """
+
         if isinstance(model_id, Path):
             model_id = model_id.as_posix()
         from_pretrained_method = cls._export if export else cls._from_pretrained
-        return from_pretrained_method(model_id=model_id, **kwargs)
+        return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)

     @classmethod
-    def compile(cls, model, rbln_compile_config: RBLNCompileConfig, **kwargs):
+    def compile(
+        cls,
+        model,
+        rbln_compile_config: RBLNCompileConfig,
+        create_runtimes: bool,
+        device: Union[int, List[int]],
+        **kwargs,
+    ):
+        if create_runtimes:
+            runtime_cannot_be_created = tp_and_devices_are_ok(
+                tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
+                device=device,
+                npu=rbln_compile_config.npu,
+            )
+            if runtime_cannot_be_created:
+                raise ValueError(runtime_cannot_be_created)
+
         compiled_model = rebel.compile_from_torch(
             model,
             input_info=rbln_compile_config.input_info,
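The new docstring and the explicit `rbln_config` parameter define the 0.8.1 loading contract. A hedged round-trip sketch (the checkpoint id and `tensor_parallel_size` value are illustrative):

```python
from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

# export=True compiles the PyTorch checkpoint; rbln_config may be a dict or a
# model-specific config instance, per the docstring above.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    export=True,
    rbln_config=RBLNLlamaForCausalLMConfig(tensor_parallel_size=4),
)
model.save_pretrained("llama-rbln")

# Reloading compiled artifacts takes the default export=False path.
model = RBLNLlamaForCausalLM.from_pretrained("llama-rbln")
```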
@@ -411,15 +414,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

     @classmethod
     def get_hf_class(cls):
-        """
-        Lazily loads and caches the corresponding HuggingFace model class.
-        Removes 'RBLN' prefix from the class name to get the original class name
-        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
-        the transformers/diffusers module.
+        # Lazily loads and caches the corresponding HuggingFace model class.
+        # Removes 'RBLN' prefix from the class name to get the original class name
+        # (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
+        # the transformers/diffusers module.

-        Returns:
-            type: The original HuggingFace model class
-        """
+        # Returns:
+        #     type: The original HuggingFace model class
         if cls._hf_class is None:
             hf_cls_name = cls.__name__[4:]
             library = importlib.import_module(cls.hf_library_name)
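The comments describe a simple naming convention; the same lookup can be sketched standalone (a hypothetical helper using the real `importlib` API):

```python
import importlib

def resolve_hf_class(rbln_cls_name: str, hf_library_name: str = "transformers") -> type:
    # Drop the 'RBLN' prefix: RBLNLlamaForCausalLM -> LlamaForCausalLM
    hf_cls_name = rbln_cls_name[len("RBLN"):]
    # Import the original class from transformers (or diffusers)
    library = importlib.import_module(hf_library_name)
    return getattr(library, hf_cls_name)
```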
@@ -428,18 +429,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

     @classmethod
     def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
-        """
-        Lazily loads and caches the corresponding RBLN model config class.
-        """
+        # Lazily loads and caches the corresponding RBLN model config class.
         if cls._rbln_config_class is None:
             rbln_config_class_name = cls.__name__ + "Config"
-            library = importlib.import_module("optimum.rbln")
-            cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
-            if cls._rbln_config_class is None:
-                raise ValueError(
-                    f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
-                    "Please report it to the developers."
-                )
+            cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
         return cls._rbln_config_class

     def can_generate(self):
@@ -449,17 +442,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return self

     def parameters(self):
-        """
-        Provides a dummy parameter generator for compatibility.
+        # A dummy parameter generator for compatibility.

-        This method mimics the interface of torch.nn.Module.parameters()
-        specifically for code that uses `next(model.parameters())` to infer
-        the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
+        # This method mimics the interface of torch.nn.Module.parameters()
+        # specifically for code that uses `next(model.parameters())` to infer
+        # the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.

-        Warning:
-            This does NOT yield the actual model parameters used by the RBLN runtime.
-            Code relying on iterating through all model parameters will not work as expected.
-        """
+        # Warning:
+        #     This does NOT yield the actual model parameters used by the RBLN runtime.
+        #     Code relying on iterating through all model parameters will not work as expected.
         yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))

     def __call__(self, *args, **kwargs):
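The warning comments are easy to skim past; concretely, the dummy generator only satisfies placement probes. A sketch, with `model` being any loaded RBLN model as above:

```python
import torch

# Supported: the common probe third-party code uses to infer device/dtype.
param = next(model.parameters())
assert param.dtype == torch.float32
assert param.device == torch.device("cpu")

# Not supported: iterating "all" parameters only ever sees the single dummy
# tensor, so parameter counts say nothing about the compiled weights.
n = sum(p.numel() for p in model.parameters())  # 1, regardless of model size
```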
@@ -547,7 +538,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

     @staticmethod
     def _raise_missing_compiled_file_error(missing_files: List[str]):
-        """Raises a KeyError with a message indicating missing compiled model files."""
+        # Raises a KeyError with a message indicating missing compiled model files.

         if len(missing_files) == 1:
             message = f"The rbln model folder is missing the required '{missing_files[0]}.rbln' file. "
@@ -563,40 +554,3 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 "and ensure the compilation completes successfully."
             )
         raise KeyError(message)
-
-    @classmethod
-    @abstractmethod
-    def _update_rbln_config(cls, **rbln_config_kwargs) -> RBLNModelConfig:
-        pass
-
-    @classmethod
-    @abstractmethod
-    def _create_runtimes(
-        cls,
-        compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_config: RBLNModelConfig,
-    ) -> List[rebel.Runtime]:
-        # compiled_models -> runtimes
-        pass
-
-    @classmethod
-    @abstractmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        pass
-
-    @classmethod
-    @abstractmethod
-    def from_model(
-        cls,
-        model: "PreTrainedModel",
-        config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNModelConfig] = None,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        subfolder: str = "",
-        **kwargs,
-    ):
-        pass
-
-    @abstractmethod
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        pass
optimum/rbln/transformers/__init__.py
CHANGED
@@ -18,16 +18,9 @@ from transformers.utils import _LazyModule


 _import_structure = {
-    "configuration_alias": [
-        "RBLNASTForAudioClassificationConfig",
-        "RBLNDistilBertForQuestionAnsweringConfig",
-        "RBLNResNetForImageClassificationConfig",
-        "RBLNXLMRobertaForSequenceClassificationConfig",
-        "RBLNRobertaForSequenceClassificationConfig",
-        "RBLNRobertaForMaskedLMConfig",
-        "RBLNViTForImageClassificationConfig",
-    ],
     "models": [
+        "RBLNASTForAudioClassification",
+        "RBLNASTForAudioClassificationConfig",
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
         "RBLNAutoModelForCausalLM",
@@ -51,12 +44,14 @@ _import_structure = {
         "RBLNBertForQuestionAnsweringConfig",
         "RBLNBertModel",
         "RBLNBertModelConfig",
-        "RBLNBlip2VisionModelConfig",
-        "RBLNBlip2VisionModel",
-        "RBLNBlip2QFormerModel",
-        "RBLNBlip2QFormerModelConfig",
         "RBLNBlip2ForConditionalGeneration",
         "RBLNBlip2ForConditionalGenerationConfig",
+        "RBLNBlip2QFormerModel",
+        "RBLNBlip2QFormerModelConfig",
+        "RBLNBlip2VisionModel",
+        "RBLNBlip2VisionModelConfig",
+        "RBLNColPaliForRetrieval",
+        "RBLNColPaliForRetrievalConfig",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",
@@ -67,40 +62,48 @@
         "RBLNCLIPVisionModelWithProjectionConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
+        "RBLNDistilBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
         "RBLNExaoneForCausalLM",
         "RBLNExaoneForCausalLMConfig",
-        "RBLNGemmaForCausalLM",
-        "RBLNGemmaForCausalLMConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
         "RBLNGemma3ForConditionalGenerationConfig",
+        "RBLNGemmaForCausalLM",
+        "RBLNGemmaForCausalLMConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
-        "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
+        "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3VisionTransformerConfig",
         "RBLNLlamaForCausalLM",
         "RBLNLlamaForCausalLMConfig",
-        "RBLNOPTForCausalLM",
-        "RBLNOPTForCausalLMConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralForCausalLM",
         "RBLNMistralForCausalLMConfig",
+        "RBLNOPTForCausalLM",
+        "RBLNOPTForCausalLMConfig",
         "RBLNPhiForCausalLM",
         "RBLNPhiForCausalLMConfig",
-        "RBLNQwen2ForCausalLM",
-        "RBLNQwen2ForCausalLMConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2ForCausalLM",
+        "RBLNQwen2ForCausalLMConfig",
+        "RBLNResNetForImageClassification",
+        "RBLNResNetForImageClassificationConfig",
+        "RBLNRobertaForMaskedLM",
+        "RBLNRobertaForMaskedLMConfig",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassificationConfig",
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
         "RBLNT5EncoderModel",
@@ -109,44 +112,23 @@
         "RBLNT5ForConditionalGenerationConfig",
         "RBLNTimeSeriesTransformerForPrediction",
         "RBLNTimeSeriesTransformerForPredictionConfig",
+        "RBLNViTForImageClassification",
+        "RBLNViTForImageClassificationConfig",
         "RBLNWav2Vec2ForCTC",
         "RBLNWav2Vec2ForCTCConfig",
         "RBLNWhisperForConditionalGeneration",
         "RBLNWhisperForConditionalGenerationConfig",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNXLMRobertaForSequenceClassificationConfig",
         "RBLNXLMRobertaModel",
         "RBLNXLMRobertaModelConfig",
     ],
-    "modeling_alias": [
-        "RBLNASTForAudioClassification",
-        "RBLNDistilBertForQuestionAnswering",
-        "RBLNResNetForImageClassification",
-        "RBLNXLMRobertaForSequenceClassification",
-        "RBLNRobertaForSequenceClassification",
-        "RBLNRobertaForMaskedLM",
-        "RBLNViTForImageClassification",
-    ],
 }

 if TYPE_CHECKING:
-    from .configuration_alias import (
-        RBLNASTForAudioClassificationConfig,
-        RBLNDistilBertForQuestionAnsweringConfig,
-        RBLNResNetForImageClassificationConfig,
-        RBLNRobertaForMaskedLMConfig,
-        RBLNRobertaForSequenceClassificationConfig,
-        RBLNViTForImageClassificationConfig,
-        RBLNXLMRobertaForSequenceClassificationConfig,
-    )
-    from .modeling_alias import (
-        RBLNASTForAudioClassification,
-        RBLNDistilBertForQuestionAnswering,
-        RBLNResNetForImageClassification,
-        RBLNRobertaForMaskedLM,
-        RBLNRobertaForSequenceClassification,
-        RBLNViTForImageClassification,
-        RBLNXLMRobertaForSequenceClassification,
-    )
     from .models import (
+        RBLNASTForAudioClassification,
+        RBLNASTForAudioClassificationConfig,
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
         RBLNAutoModelForCausalLM,
@@ -186,6 +168,8 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjectionConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
+        RBLNDistilBertForQuestionAnswering,
+        RBLNDistilBertForQuestionAnsweringConfig,
         RBLNDPTForDepthEstimation,
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
@@ -220,6 +204,12 @@
         RBLNQwen2_5_VLForConditionalGenerationConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
+        RBLNResNetForImageClassification,
+        RBLNResNetForImageClassificationConfig,
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForMaskedLMConfig,
+        RBLNRobertaForSequenceClassification,
+        RBLNRobertaForSequenceClassificationConfig,
         RBLNSiglipVisionModel,
         RBLNSiglipVisionModelConfig,
         RBLNT5EncoderModel,
@@ -228,10 +218,14 @@
         RBLNT5ForConditionalGenerationConfig,
         RBLNTimeSeriesTransformerForPrediction,
         RBLNTimeSeriesTransformerForPredictionConfig,
+        RBLNViTForImageClassification,
+        RBLNViTForImageClassificationConfig,
         RBLNWav2Vec2ForCTC,
         RBLNWav2Vec2ForCTCConfig,
         RBLNWhisperForConditionalGeneration,
         RBLNWhisperForConditionalGenerationConfig,
+        RBLNXLMRobertaForSequenceClassification,
+        RBLNXLMRobertaForSequenceClassificationConfig,
         RBLNXLMRobertaModel,
         RBLNXLMRobertaModelConfig,
     )
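Taken together, these `__init__.py` hunks fold the `configuration_alias`/`modeling_alias` entries into the `models` entry of `_import_structure`, so the public import path should be unchanged; for example:

```python
# Works the same in 0.8.0.post2 and 0.8.1; only the defining module moved
# (configuration_alias.py/modeling_alias.py were folded into models/).
from optimum.rbln.transformers import (
    RBLNResNetForImageClassification,
    RBLNResNetForImageClassificationConfig,
)
```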
optimum/rbln/transformers/configuration_generic.py
CHANGED
@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from ..configuration_utils import RBLNModelConfig


-class _RBLNTransformerEncoderConfig(RBLNModelConfig):
+class RBLNTransformerEncoderConfig(RBLNModelConfig):
     rbln_model_input_names: Optional[List[str]] = None

     def __init__(
@@ -25,7 +25,7 @@ class _RBLNTransformerEncoderConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         batch_size: Optional[int] = None,
         model_input_names: Optional[List[str]] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
@@ -47,9 +47,12 @@ class _RBLNTransformerEncoderConfig(RBLNModelConfig):
         self.model_input_names = model_input_names or self.rbln_model_input_names


-class _RBLNImageModelConfig(RBLNModelConfig):
+class RBLNImageModelConfig(RBLNModelConfig):
     def __init__(
-        self, image_size: Optional[Union[int, Tuple[int, int]]] = None, batch_size: Optional[int] = None, **kwargs
+        self,
+        image_size: Optional[Union[int, Tuple[int, int]]] = None,
+        batch_size: Optional[int] = None,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
@@ -86,32 +89,32 @@ class _RBLNImageModelConfig(RBLNModelConfig):
         return self.image_size["height"]


-class RBLNModelForQuestionAnsweringConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForQuestionAnsweringConfig(RBLNTransformerEncoderConfig):
     pass


-class RBLNModelForSequenceClassificationConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForSequenceClassificationConfig(RBLNTransformerEncoderConfig):
     pass


-class RBLNModelForMaskedLMConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForMaskedLMConfig(RBLNTransformerEncoderConfig):
     pass


-class RBLNModelForTextEncodingConfig(_RBLNTransformerEncoderConfig):
+class RBLNModelForTextEncodingConfig(RBLNTransformerEncoderConfig):
     pass


 # FIXME : Appropriate name ?
-class RBLNTransformerEncoderForFeatureExtractionConfig(_RBLNTransformerEncoderConfig):
+class RBLNTransformerEncoderForFeatureExtractionConfig(RBLNTransformerEncoderConfig):
     pass


-class RBLNModelForImageClassificationConfig(_RBLNImageModelConfig):
+class RBLNModelForImageClassificationConfig(RBLNImageModelConfig):
     pass


-class RBLNModelForDepthEstimationConfig(_RBLNImageModelConfig):
+class RBLNModelForDepthEstimationConfig(RBLNImageModelConfig):
     pass


@@ -121,7 +124,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         max_length: Optional[int] = None,
         num_mel_bins: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
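Only the `**kwargs` annotations change in this file, so construction is unaffected. A hedged example against the audio config shown above (the import path and the values are illustrative assumptions, not taken from this diff):

```python
from optimum.rbln.transformers.configuration_generic import (
    RBLNModelForAudioClassificationConfig,  # assumed importable from its defining module
)

# Same keyword arguments as before; only the **kwargs type hint changed.
config = RBLNModelForAudioClassificationConfig(batch_size=1, max_length=1024, num_mel_bins=128)
```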
optimum/rbln/transformers/modeling_generic.py
CHANGED
@@ -43,9 +43,9 @@ from ..configuration_utils import RBLNCompileConfig
 from ..modeling import RBLNModel
 from ..utils.logging import get_logger
 from .configuration_generic import (
+    RBLNImageModelConfig,
     RBLNModelForAudioClassificationConfig,
-    _RBLNImageModelConfig,
-    _RBLNTransformerEncoderConfig,
+    RBLNTransformerEncoderConfig,
 )


@@ -55,7 +55,7 @@ if TYPE_CHECKING:
 logger = get_logger()


-class _RBLNTransformerEncoder(RBLNModel):
+class RBLNTransformerEncoder(RBLNModel):
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     rbln_dtype = "int64"
@@ -66,8 +66,8 @@ class _RBLNTransformerEncoder(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNTransformerEncoderConfig] = None,
-    ) -> _RBLNTransformerEncoderConfig:
+        rbln_config: Optional[RBLNTransformerEncoderConfig] = None,
+    ) -> RBLNTransformerEncoderConfig:
         return cls.update_rbln_config_for_transformers_encoder(
             preprocessors=preprocessors,
             model=model,
@@ -81,8 +81,8 @@ class _RBLNTransformerEncoder(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNTransformerEncoderConfig] = None,
-    ) -> _RBLNTransformerEncoderConfig:
+        rbln_config: Optional[RBLNTransformerEncoderConfig] = None,
+    ) -> RBLNTransformerEncoderConfig:
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -139,7 +139,7 @@ class _RBLNTransformerEncoder(RBLNModel):
         return rbln_config


-class _RBLNImageModel(RBLNModel):
+class RBLNImageModel(RBLNModel):
     auto_model_class = AutoModel
     main_input_name = "pixel_values"
     output_class = BaseModelOutput
@@ -150,8 +150,8 @@ class _RBLNImageModel(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNImageModelConfig] = None,
-    ) -> _RBLNImageModelConfig:
+        rbln_config: Optional[RBLNImageModelConfig] = None,
+    ) -> RBLNImageModelConfig:
         return cls.update_rbln_config_for_image_model(
             preprocessors=preprocessors,
             model=model,
@@ -165,8 +165,8 @@ class _RBLNImageModel(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional["PreTrainedModel"] = None,
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[_RBLNImageModelConfig] = None,
-    ) -> _RBLNImageModelConfig:
+        rbln_config: Optional[RBLNImageModelConfig] = None,
+    ) -> RBLNImageModelConfig:
         if rbln_config.image_size is None:
             for processor in preprocessors:
                 if hasattr(processor, "size"):
@@ -196,15 +196,14 @@ class _RBLNImageModel(RBLNModel):
         return rbln_config


-class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
+class RBLNModelForQuestionAnswering(RBLNTransformerEncoder):
     auto_model_class = AutoModelForQuestionAnswering
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     output_class = QuestionAnsweringModelOutput

     def _prepare_output(self, output, return_dict):
-        """
-        Prepare QuestionAnswering specific output format.
-        """
+        # Prepare QuestionAnswering specific output format.
+
         start_logits, end_logits = output

         if not return_dict:
@@ -213,32 +212,32 @@ class RBLNModelForQuestionAnswering(_RBLNTransformerEncoder):
         return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)


-class RBLNModelForSequenceClassification(_RBLNTransformerEncoder):
+class RBLNModelForSequenceClassification(RBLNTransformerEncoder):
     auto_model_class = AutoModelForSequenceClassification
     rbln_model_input_names = ["input_ids", "attention_mask"]


-class RBLNModelForMaskedLM(_RBLNTransformerEncoder):
+class RBLNModelForMaskedLM(RBLNTransformerEncoder):
     auto_model_class = AutoModelForMaskedLM
     rbln_model_input_names = ["input_ids", "attention_mask"]


-class RBLNModelForTextEncoding(_RBLNTransformerEncoder):
+class RBLNModelForTextEncoding(RBLNTransformerEncoder):
     auto_model_class = AutoModelForTextEncoding
     rbln_model_input_names = ["input_ids", "attention_mask"]


-class RBLNTransformerEncoderForFeatureExtraction(_RBLNTransformerEncoder):
+class RBLNTransformerEncoderForFeatureExtraction(RBLNTransformerEncoder):
     # TODO: RBLNModel is also for feature extraction.
     auto_model_class = AutoModel
     rbln_model_input_names = ["input_ids", "attention_mask"]


-class RBLNModelForImageClassification(_RBLNImageModel):
+class RBLNModelForImageClassification(RBLNImageModel):
     auto_model_class = AutoModelForImageClassification


-class RBLNModelForDepthEstimation(_RBLNImageModel):
+class RBLNModelForDepthEstimation(RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation


optimum/rbln/transformers/modeling_rope_utils.py
CHANGED
@@ -48,10 +48,13 @@ def _compute_default_rope_parameters(
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    head_dim = (
+        config.head_dim
+        if hasattr(config, "head_dim") and config.head_dim is not None
+        else config.hidden_size // config.num_attention_heads
+    )
     dim = int(head_dim * partial_rotary_factor)

     attention_factor = 1.0  # Unused in this type of RoPE
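The rewrite matters for configs that define `head_dim` but leave it as `None`: a one-line `getattr(config, "head_dim", default)` (the form used by the upstream transformers helper, and the assumed pre-change line here) returns that `None` instead of computing the fallback. A self-contained sketch of the new logic (dummy config, illustrative sizes):

```python
class DummyConfig:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute exists but is unset, as in some HF configs

config = DummyConfig()

# getattr(config, "head_dim", ...) would return None here; the explicit
# None-check falls through to hidden_size // num_attention_heads = 128.
head_dim = (
    config.head_dim
    if hasattr(config, "head_dim") and config.head_dim is not None
    else config.hidden_size // config.num_attention_heads
)
assert head_dim == 128
```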
|