optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,53 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
|
25
|
+
# Class-name prefix that distinguishes RBLN model classes from their
# Hugging Face counterparts, e.g. "BertModel" <-> "RBLNBertModel".
RBLN_PREFIX = "RBLN"


def convert_hf_to_rbln_model_name(hf_model_name: str):
    """Map a Hugging Face model class name to its RBLN class name.

    The RBLN name is built by prepending ``RBLN_PREFIX``. No check is made for
    whether the input already carries the prefix.

    Args:
        hf_model_name (str): Hugging Face model class name, e.g. ``"BertModel"``.

    Returns:
        str: The prefixed RBLN class name, e.g. ``"RBLNBertModel"``.
    """
    rbln_model_name = RBLN_PREFIX + hf_model_name
    return rbln_model_name


def convert_rbln_to_hf_model_name(rbln_model_name: str):
    """Map an RBLN model class name back to its Hugging Face class name.

    A single leading ``RBLN_PREFIX`` is stripped when present. Names that do
    not start with the prefix are returned unchanged.

    Args:
        rbln_model_name (str): RBLN model class name, e.g. ``"RBLNBertModel"``.

    Returns:
        str: The Hugging Face class name, e.g. ``"BertModel"``.
    """
    if not rbln_model_name.startswith(RBLN_PREFIX):
        return rbln_model_name
    return rbln_model_name[len(RBLN_PREFIX):]
|
@@ -43,7 +43,7 @@ class RBLNPytorchRuntime:
|
|
43
43
|
return self.forward(*args, **kwds)
|
44
44
|
|
45
45
|
def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
|
46
|
-
# filtering
|
46
|
+
# filtering useless args or kwarg such as None.
|
47
47
|
args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
|
48
48
|
kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
|
49
49
|
output = self.runtime(*args, **kwargs)
|
@@ -67,7 +67,7 @@ class UnavailableRuntime:
|
|
67
67
|
return iter([self])
|
68
68
|
|
69
69
|
def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
|
70
|
-
raise RuntimeError("
|
70
|
+
raise RuntimeError("The model can't run because the runtime hasn't been created.")
|
71
71
|
|
72
72
|
def __repr__(self) -> str:
|
73
73
|
return "UnavailableRuntime"
|
@@ -76,17 +76,17 @@ class UnavailableRuntime:
|
|
76
76
|
class ContextRblnConfig:
|
77
77
|
_local = threading.local()
|
78
78
|
|
79
|
-
def __init__(self, device=None, device_map=None, create_runtimes=None,
|
79
|
+
def __init__(self, device=None, device_map=None, create_runtimes=None, optimize_host_mem=None):
|
80
80
|
self.device = device
|
81
81
|
self.device_map = device_map
|
82
82
|
self.create_runtimes = create_runtimes
|
83
|
-
self.
|
83
|
+
self.optimize_host_mem = optimize_host_mem
|
84
84
|
|
85
85
|
def __enter__(self):
|
86
86
|
self._local.device = self.device
|
87
87
|
self._local.device_map = self.device_map
|
88
88
|
self._local.create_runtimes = self.create_runtimes
|
89
|
-
self._local.optimize_host_memory = self.
|
89
|
+
self._local.optimize_host_memory = self.optimize_host_mem
|
90
90
|
return self
|
91
91
|
|
92
92
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
@@ -0,0 +1,114 @@
|
|
1
|
+
import importlib
|
2
|
+
from pathlib import Path
|
3
|
+
from typing import TYPE_CHECKING, Any, Dict, List
|
4
|
+
|
5
|
+
from ..modeling_config import RBLNConfig
|
6
|
+
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from transformers import PreTrainedModel
|
10
|
+
|
11
|
+
from ..modeling_base import RBLNBaseModel
|
12
|
+
|
13
|
+
|
14
|
+
class SubModulesMixin:
    """Mixin that compiles or loads RBLN sub-models nested inside a parent model.

    Subclasses declare their submodules in ``_rbln_submodules``. This is a list of
    dicts whose ``"name"`` key is the attribute name of the submodule on the parent
    model, for example::

        _rbln_submodules = [
            {"name": "vision_tower"},
            {"name": "language_model"},
        ]

    Each submodule is compiled (or loaded) using a ``subfolder`` equal to its
    attribute name, relative to ``model_save_dir``.
    """

    _rbln_submodules: List[Dict[str, Any]] = []

    def __init__(
        self,
        *,
        rbln_submodules: List["RBLNBaseModel"] = [],
        **kwargs,
    ) -> None:
        """Attach already-built RBLN submodules to this instance.

        Args:
            rbln_submodules: RBLN submodules in the same order as
                ``_rbln_submodules``. The default list is only iterated, never mutated.
            **kwargs: Accepted and ignored here.
        """
        # Pairing is positional; zip stops at the shorter of the two lists.
        for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
            setattr(self, submodule_meta["name"], submodule)

    @staticmethod
    def _build_submodule_kwargs(
        submodule_name: str,
        rbln_kwargs: Dict[str, Any],
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Return a fresh keyword-argument dict for a single submodule.

        The input ``kwargs`` is copied, never mutated, so an ``rbln_config``
        chosen for one submodule cannot leak into a later submodule that has no
        entry in ``rbln_kwargs``.

        Args:
            submodule_name: Attribute name of the submodule.
            rbln_kwargs: Per-submodule RBLN settings keyed by submodule name.
            kwargs: Keyword arguments shared by all submodules.

        Returns:
            A new dict equal to ``kwargs``. When ``submodule_name`` is present in
            ``rbln_kwargs``, its ``"rbln_config"`` is set to that entry.
        """
        submodule_kwargs = dict(kwargs)
        if submodule_name in rbln_kwargs:
            submodule_kwargs["rbln_config"] = rbln_kwargs[submodule_name]
        return submodule_kwargs

    @classmethod
    def _export_submodules_from_model(
        cls,
        model: "PreTrainedModel",
        model_save_dir: str,
        rbln_kwargs: Dict[str, Any],
        **kwargs,
    ) -> List["RBLNBaseModel"]:
        """Compile every declared submodule of a PyTorch ``model``.

        For each entry in ``_rbln_submodules``, the PyTorch submodule is read from
        ``model`` by attribute name. The ``optimum.rbln`` class named
        ``RBLN<ClassName>`` is then looked up, and its ``from_model`` is called
        with ``subfolder`` set to the submodule name.

        Args:
            model: Parent PyTorch model holding the submodules as attributes.
            model_save_dir: Directory forwarded to each ``from_model`` call.
            rbln_kwargs: Per-submodule RBLN settings keyed by submodule name. An
                entry, if present, becomes that submodule's ``rbln_config``.
            **kwargs: Extra keyword arguments forwarded to every ``from_model`` call.

        Returns:
            The compiled RBLN submodules, in ``_rbln_submodules`` order.

        Raises:
            AttributeError: If ``model`` lacks a declared submodule attribute, or
                ``optimum.rbln`` has no ``RBLN<ClassName>`` attribute.
        """
        rbln_submodules = []
        for submodule in cls._rbln_submodules:
            submodule_name = submodule["name"]
            torch_submodule: "PreTrainedModel" = getattr(model, submodule_name)
            cls_name = torch_submodule.__class__.__name__
            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")

            # Per-submodule copy: avoids reusing the previous submodule's rbln_config.
            submodule_kwargs = cls._build_submodule_kwargs(submodule_name, rbln_kwargs, kwargs)

            rbln_submodule = submodule_cls.from_model(
                model=torch_submodule,
                subfolder=submodule_name,
                model_save_dir=model_save_dir,
                **submodule_kwargs,
            )

            rbln_submodules.append(rbln_submodule)

        return rbln_submodules

    @classmethod
    def _load_submodules_from_compiled_models(
        cls,
        model_save_dir: str,
        rbln_kwargs: Dict[str, Any],
        **kwargs,
    ) -> List["RBLNBaseModel"]:
        """Load every declared submodule from previously compiled artifacts.

        For each submodule, the RBLN config saved at
        ``model_save_dir/<submodule name>`` is read. Its ``meta["cls"]`` selects the
        ``optimum.rbln`` class, whose ``_from_pretrained`` is then called.

        Args:
            model_save_dir: Directory containing one subfolder per submodule.
            rbln_kwargs: Per-submodule RBLN settings keyed by submodule name. An
                entry, if present, becomes that submodule's ``rbln_config``.
            **kwargs: Extra keyword arguments forwarded to every ``_from_pretrained`` call.

        Returns:
            The loaded RBLN submodules, in ``_rbln_submodules`` order.
        """
        rbln_submodules = []
        for submodule in cls._rbln_submodules:
            submodule_name = submodule["name"]

            # Per-submodule copy: avoids reusing the previous submodule's rbln_config.
            submodule_kwargs = cls._build_submodule_kwargs(submodule_name, rbln_kwargs, kwargs)

            # The saved config records which RBLN class produced this submodule;
            # use it to pick the class whose constructor to call.
            submodule_rbln_config = RBLNConfig.load(Path(model_save_dir) / submodule_name)
            submodule_cls_name = submodule_rbln_config.meta["cls"]
            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), submodule_cls_name)

            rbln_submodule = submodule_cls._from_pretrained(
                model_id=model_save_dir,
                config=None,
                subfolder=submodule_name,
                **submodule_kwargs,
            )
            rbln_submodules.append(rbln_submodule)
        return rbln_submodules

    @classmethod
    def _load_submodules(
        cls,
        model_save_dir,
        rbln_kwargs,
        model=None,
        **kwargs,
    ):
        """Build the RBLN submodules, either by compiling or by loading.

        There are two paths:

        1. When ``model`` (a PyTorch object) is given, compile its submodules.
        2. Otherwise, load them from compiled files under ``model_save_dir``.

        Returns:
            The list of RBLN submodules, in ``_rbln_submodules`` order.
        """
        if model is not None:
            return cls._export_submodules_from_model(
                model=model,
                model_save_dir=model_save_dir,
                rbln_kwargs=rbln_kwargs,
                **kwargs,
            )

        else:
            return cls._load_submodules_from_compiled_models(
                model_save_dir=model_save_dir,
                rbln_kwargs=rbln_kwargs,
                **kwargs,
            )
|
@@ -12,11 +12,11 @@ logger = get_logger()
|
|
12
12
|
def rbln_timer(print_name):
|
13
13
|
def decorator(function):
|
14
14
|
def wrapper(*args, **kwargs):
|
15
|
-
disable = os.getenv("OPTIMUM_RBLN_DISABLE_SPIN",
|
15
|
+
disable = os.getenv("OPTIMUM_RBLN_DISABLE_SPIN", "False").lower() in ("true", "1", "t")
|
16
16
|
if disable:
|
17
17
|
logger.info(f"{print_name} ...")
|
18
18
|
|
19
|
-
spinner = Halo(text=f"{print_name} ...", spinner=
|
19
|
+
spinner = Halo(text=f"{print_name} ...", spinner="dots", color="green", enabled=(not disable))
|
20
20
|
spinner.start()
|
21
21
|
|
22
22
|
# Start timer
|
@@ -0,0 +1,106 @@
|
|
1
|
+
Metadata-Version: 2.3
|
2
|
+
Name: optimum-rbln
|
3
|
+
Version: 0.1.15
|
4
|
+
Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
|
5
|
+
Project-URL: Homepage, https://rebellions.ai
|
6
|
+
Project-URL: Documentation, https://docs.rbln.ai
|
7
|
+
Author-email: "Rebellions Inc." <support@rebellions.ai>
|
8
|
+
License: Apache
|
9
|
+
Keywords: atom,diffusers,inference,rbln,rebel,transformers
|
10
|
+
Classifier: Development Status :: 2 - Pre-Alpha
|
11
|
+
Classifier: Intended Audience :: Developers
|
12
|
+
Classifier: Intended Audience :: Education
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
14
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
15
|
+
Classifier: Operating System :: POSIX :: Linux
|
16
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
17
|
+
Classifier: Programming Language :: Python :: 3.9
|
18
|
+
Classifier: Programming Language :: Python :: 3.10
|
19
|
+
Classifier: Programming Language :: Python :: 3.11
|
20
|
+
Classifier: Programming Language :: Python :: 3.12
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
22
|
+
Requires-Python: <3.13,>=3.9
|
23
|
+
Requires-Dist: accelerate>=1.0.1
|
24
|
+
Requires-Dist: diffusers<=0.31.0
|
25
|
+
Requires-Dist: einops>=0.8.0
|
26
|
+
Requires-Dist: halo>=0.0.31
|
27
|
+
Requires-Dist: packaging>=24.1
|
28
|
+
Requires-Dist: torch<=2.5.1
|
29
|
+
Requires-Dist: torchaudio<=2.5.1
|
30
|
+
Requires-Dist: torchvision<=0.20.1
|
31
|
+
Requires-Dist: transformers==4.45.2
|
32
|
+
Description-Content-Type: text/markdown
|
33
|
+
|
34
|
+
|
35
|
+
# Optimum RBLN
|
36
|
+
|
37
|
+
<div align="center">
|
38
|
+
|
39
|
+
<img src="assets/rbln_logo.png" width="60%"/>
|
40
|
+
|
41
|
+
[![PyPI version](https://badge.fury.io/py/optimum-rbln.svg)](https://badge.fury.io/py/optimum-rbln)
|
42
|
+
[![License](https://img.shields.io/github/license/rebellions-sw/optimum-rbln)](https://github.com/rebellions-sw/optimum-rbln/blob/main/LICENSE)
|
43
|
+
|
44
|
+
</div>
|
45
|
+
|
46
|
+
🤗 Optimum RBLN provides an interface between Hugging Face libraries ([Transformers](https://huggingface.co/docs/transformers), [Diffusers](https://huggingface.co/docs/diffusers/index)) and RBLN Accelerators, including [ATOM](https://rebellions.ai/rebellions-product/rbln-ca25/) and [REBEL](https://rebellions.ai/rebellions-product/rebel/).
|
47
|
+
|
48
|
+
This library enables seamless integration between the Hugging Face ecosystem and RBLN's NPU acceleration through a comprehensive toolkit for model loading and inference across single- and multi-Accelerator environments. While we maintain a list of [officially validated models and tasks](https://docs.rbln.ai/software/optimum/optimum_rbln.html), users can easily adapt other models and tasks with minimal modifications.
|
49
|
+
|
50
|
+
## Key Features
|
51
|
+
|
52
|
+
🚀 **High Performance Inference**
|
53
|
+
- Optimized model execution on RBLN NPUs through RBLN SDK compilation
|
54
|
+
- Support for both single-NPU and multi-NPU inference
|
55
|
+
- Integrated with RBLN Runtime for optimal performance
|
56
|
+
|
57
|
+
🔧 **Easy Integration**
|
58
|
+
- Seamless compatibility with the Hugging Face Model Hub
|
59
|
+
- Drop-in replacement for existing Hugging Face pipelines
|
60
|
+
- Minimal code changes required for NPU acceleration
|
61
|
+
|
62
|
+
|
63
|
+
## Documentation
|
64
|
+
|
65
|
+
Check out [the documentation of Optimum RBLN](https://docs.rbln.ai/software/optimum/optimum_rbln.html) for more advanced usage.
|
66
|
+
|
67
|
+
## Getting Started
|
68
|
+
|
69
|
+
### Install from PyPI
|
70
|
+
|
71
|
+
To install the latest release of this package:
|
72
|
+
|
73
|
+
- Export the environment variables required to access the RBLN private PyPI:
|
74
|
+
```bash
|
75
|
+
export REBEL_PYPI_USERNAME=<username>
|
76
|
+
export REBEL_PYPI_PASSWORD=<password>
|
77
|
+
```
|
78
|
+
|
79
|
+
- Install optimum-rbln package:
|
80
|
+
```bash
|
81
|
+
pip install --index-url https://pypi.rebellions.in/simple optimum-rbln
|
82
|
+
```
|
83
|
+
|
84
|
+
### Install from source
|
85
|
+
|
86
|
+
#### Prerequisites
|
87
|
+
|
88
|
+
- Install [uv](https://docs.astral.sh/uv/) (see [the uv installation guide](https://docs.astral.sh/uv/getting-started/installation/) for detailed instructions)
|
89
|
+
|
90
|
+
The following commands install optimum-rbln along with its dependencies.
|
91
|
+
|
92
|
+
```bash
|
93
|
+
git clone https://github.com/rebellions-sw/optimum-rbln.git
|
94
|
+
cd optimum-rbln
|
95
|
+
./scripts/uv-sync.sh
|
96
|
+
```
|
97
|
+
|
98
|
+
To install a local rebel-compiler in editable mode in the uv environment, run:
|
99
|
+
```bash
|
100
|
+
uv pip install -e /path/to/rebel_compiler/python
|
101
|
+
```
|
102
|
+
|
103
|
+
### Need Help?
|
104
|
+
|
105
|
+
- Join our [Developer Community](https://discuss.rebellions.ai/)
|
106
|
+
- Contact maintainers at [support@rebellions.ai](mailto:support@rebellions.ai)
|
@@ -0,0 +1,110 @@
|
|
1
|
+
optimum/rbln/__init__.py,sha256=rjaGo_lPR8m4RwnTYuLTOL15KNRKXbD2EGn7j_STXIg,6895
|
2
|
+
optimum/rbln/__version__.py,sha256=ZKlmJ822TJ49YEqc2wCAMbrp81vFvzcFa9OTia84voM,413
|
3
|
+
optimum/rbln/modeling.py,sha256=GpTLugUsFx5qTjyENwR7263naVZrMugtoVvWFEaQLzQ,23788
|
4
|
+
optimum/rbln/modeling_alias.py,sha256=Z9vGv6ca82_mhbYclxIZ6e8jt-gf07g--k3ljdQvtGo,2128
|
5
|
+
optimum/rbln/modeling_base.py,sha256=TPcJ8JhFvWepIrmPuMQp_IKLWlTmvy2Wb99rhoz_YDk,19755
|
6
|
+
optimum/rbln/modeling_config.py,sha256=va58Gpbn3rolqKu9y2u3vYVT6kynBGpox_jod6cs-j0,10612
|
7
|
+
optimum/rbln/modeling_diffusers.py,sha256=VabNyhVN5s8M_fCx18SkR9hAfJqfXBZwz1m4Sl9Yihg,14138
|
8
|
+
optimum/rbln/diffusers/__init__.py,sha256=jad5hGtgXfP6ZZzYI4uBnb1Qbt6TwfEIJjFOtdNzCgc,3187
|
9
|
+
optimum/rbln/diffusers/models/__init__.py,sha256=CKgWCqCEPrAc-l5SxKcwu7TadkSGvqkpNqpwrXZVv90,1749
|
10
|
+
optimum/rbln/diffusers/models/controlnet.py,sha256=rIYshEXkBqAGh7cOpfu2quffVHNJj9SQ-ATsgQkre5o,10889
|
11
|
+
optimum/rbln/diffusers/models/autoencoders/__init__.py,sha256=yc1ABZG3xxzWPDGf0ADEeuSz3Nrq4ZP-CwddQ-VvWCU,1039
|
12
|
+
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py,sha256=EoI8EQYYS0SA4wurM3I2shs5Z6MA-YEXohoxDg40LrE,8554
|
13
|
+
optimum/rbln/diffusers/models/autoencoders/vae.py,sha256=Ys802twDAnNIMVRp-lL3Rhb8Gh-jot0IFCLBm68FrK8,2889
|
14
|
+
optimum/rbln/diffusers/models/transformers/__init__.py,sha256=2bVk_6nuqcREOIVyR-4w5ksmdWJqyIV7Wxc5x0dqYO8,1048
|
15
|
+
optimum/rbln/diffusers/models/transformers/transformer_sd3.py,sha256=C_50IkKUYnpCbRlsTsX07PVI5RLRgnouQyISvZhlVOg,7684
|
16
|
+
optimum/rbln/diffusers/models/unets/__init__.py,sha256=-0PyRbBVBFujd7nBh0Z4NOe3RVOlAWyvWLU8r62dqdo,1049
|
17
|
+
optimum/rbln/diffusers/models/unets/unet_2d_condition.py,sha256=p-obj4tVxBkQMS8W1On8oFrLK-TXCo3Zksw1bOBXRPw,14467
|
18
|
+
optimum/rbln/diffusers/pipelines/__init__.py,sha256=M6UtFsGnGKYnuHkuJnfyR5WzajAgUidVUNGaqLBS6bM,2862
|
19
|
+
optimum/rbln/diffusers/pipelines/controlnet/__init__.py,sha256=k0govvSBxBUR5qpxUGxRMHuQCMX7hXHVZ4EqVRw1LWk,1377
|
20
|
+
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py,sha256=OnOvqnNCK1WrnU7JH93GSmEdMRW8Z2__eorYaj-zHAw,4424
|
21
|
+
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py,sha256=xlfDdKxgIK1DvIQAbOipA00LAMni8f7z8urpeOAyEkE,34884
|
22
|
+
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py,sha256=C9DwjPiNSQanqNNWiOHJvw3yZtL8tR4YaQZwVkC02_o,33369
|
23
|
+
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py,sha256=QvoS7g3dfZ3b1Fq1taIfRBf9kNpoR0jsv-H97hXNQ4A,44466
|
24
|
+
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py,sha256=hHkNKnBsSGdHeFjf45zIpum0DE6JRcw_NEYFLYYmxSg,45779
|
25
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py,sha256=7lX6f9XiqROsArw1X0lGsW06H0TrWKwvM9Ye16vis9Y,1224
|
26
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py,sha256=33LQR11kryvQqfQ0Srl8O8QqRsfJnKYUZAQHHSruKTY,1362
|
27
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py,sha256=WJ4pRfmE0DIAQANXoytoeVyKBa750j-f5oKOHLUmoyY,1390
|
28
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py,sha256=DaXqP_2oik-L7LitAXtaiMqcOZixOui67BWSU_qoxu4,1397
|
29
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py,sha256=I1IbI4uo2ZqA2uLbhqW67wW7-8DwqeE-qPGc3BTL7dQ,1233
|
30
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py,sha256=Hb4cGwz422jWvFefTfrPMSt7__KEdlrq41OjiSpNtFo,1410
|
31
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py,sha256=hQUAMqQn62-aezWlwscyJu6QlXjXByJ-NXnrvYVIpRk,1445
|
32
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py,sha256=8PorhGURHd01eob_4pkZhH6azjqk77FS2xZFnBED5Yg,1445
|
33
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py,sha256=giP9dJokdRT4-a5wdJqA1qW8os5Zz0huCack8nlcyxs,246
|
34
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py,sha256=XeeZqH63No7NPIXKXgv8Vxq_I0-iql4jHJ0sIkbYSvw,1390
|
35
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py,sha256=o1WTLCFjsDZAtGZ84cykU0F4Qw_iMP0g7jaPw1xobP0,1418
|
36
|
+
optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py,sha256=aozAYG2mba2G9ITspDVuos-eWy0MoxlwnVtxGgtN7zk,1425
|
37
|
+
optimum/rbln/transformers/__init__.py,sha256=H__SYW4dhcFb02euqtfFZx212QZxkmKu4kgd2sBtVAs,3756
|
38
|
+
optimum/rbln/transformers/cache_utils.py,sha256=Ak6yJlzkXsu5jQ-kWIgO97GrsLpoCezpEgQoZnvjmec,3796
|
39
|
+
optimum/rbln/transformers/modeling_rope_utils.py,sha256=ob8haNW5f0tPq26xd4sTg-rMrBT9M2VDP7wxt-PkvYY,13087
|
40
|
+
optimum/rbln/transformers/models/__init__.py,sha256=gffOrFFYG3W8ypxpDiUotU-McvjhytffeuUzunjv4nQ,3971
|
41
|
+
optimum/rbln/transformers/models/auto/__init__.py,sha256=zMqaMIdGwuZJq4gLjRC-69M2mGUKrd0aRpmb4Rpm6-g,435
|
42
|
+
optimum/rbln/transformers/models/auto/auto_factory.py,sha256=JIFL404RVf6kAONhLeEz6z59tdahcUAyVSU8hdZZe0g,7421
|
43
|
+
optimum/rbln/transformers/models/auto/modeling_auto.py,sha256=DDx8ZUpWPtTr8ZNv7cO_dUgJtDLXjax7gdxwc-wkkgQ,4328
|
44
|
+
optimum/rbln/transformers/models/bart/__init__.py,sha256=-mrB4kmPpEIVk713yEIRtC57AZ7kZ23g4bsOKcvBFNE,1138
|
45
|
+
optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=W6XeumvuKv1i7p4YzKM0NgpO3NCnc1qwGXknZZrPlP0,21298
|
46
|
+
optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=xzZZf-yZdUkq4F271Wfd1l-Hnm4jjgf_yy6hjVohxbo,5144
|
47
|
+
optimum/rbln/transformers/models/bert/__init__.py,sha256=divBpVNrRAdNAPgnQkGiEZI4oJHCJtLuwdYpMbD3dMM,1034
|
48
|
+
optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=akbsBTsGTs7wrxPw120ryZspwYkmHAUrM4A8Kr3COw4,4111
|
49
|
+
optimum/rbln/transformers/models/clip/__init__.py,sha256=iXZfPPIztzMDOkY3fbEzx9dCkFKKtWYXCpLGfjEUeZE,1092
|
50
|
+
optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=DI_N-bQcA_Kj7NSkv9VPvV1zsN6IscctVczE_2_ZkVM,6089
|
51
|
+
optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=ozc0c3XBI3-5VHhGvZ0zcv6TD-kIXpDCqsAvdW3JSaY,1222
|
52
|
+
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=ratAjwMF3eiHTaXwPNqHqWzmkGEC5fG37xP4mGJOMI8,36833
|
53
|
+
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=esQp936mlHuAjb1A75jepyB3QVXsUR7a5WgVUhV7wJw,26132
|
54
|
+
optimum/rbln/transformers/models/dpt/__init__.py,sha256=R8OHDxOAYPjkk5t8osaPqRh85Pf1Cg1BtzqesqFRcTI,1045
|
55
|
+
optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=Re15veJwAz3NaUv2GfrHyGUblW0Jcd2KLv23GutFp64,3805
|
56
|
+
optimum/rbln/transformers/models/exaone/__init__.py,sha256=CuWNwnZrbd_cLU7jDDPqC0kZIqx1ii_BYyQ98CKDag0,1253
|
57
|
+
optimum/rbln/transformers/models/exaone/exaone_architecture.py,sha256=sf0GF35u2AtyadR6WLxaau_0-JKusSomtfh0ILQMras,3528
|
58
|
+
optimum/rbln/transformers/models/exaone/modeling_exaone.py,sha256=-0VjxpBOQLM3PAmrWCJXkTKQEap577bS_izn-gx8Ew8,2141
|
59
|
+
optimum/rbln/transformers/models/gemma/__init__.py,sha256=L1Qfr6dufWtoUZND_ofwxXPSxivEvPR8exy16a_nM2o,1042
|
60
|
+
optimum/rbln/transformers/models/gemma/gemma_architecture.py,sha256=4irFBUeL1rEnHl-r5be_doz4QaqDN3jYZMcN1aHMLYo,2317
|
61
|
+
optimum/rbln/transformers/models/gemma/modeling_gemma.py,sha256=floBHXgogv3iAPyjhjKzbjFHeH67L3FYUKs_wtRm-gM,1924
|
62
|
+
optimum/rbln/transformers/models/gpt2/__init__.py,sha256=jsOKYXUclG9G6cwUTUX4eeKqjCPfQUwev7TTFIMXS4Y,1040
|
63
|
+
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py,sha256=w4dAVeubsNkGtcapknwyyQ5VevWPTWETg4M6Y_tZ9UI,3359
|
64
|
+
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py,sha256=uz29eh8bLWxsm8pVHwvA-X8FThW2khSO-Rjysp3RoQk,1910
|
65
|
+
optimum/rbln/transformers/models/llama/__init__.py,sha256=5mX-MuKzVBj6WQeVxyPhtvFTv0jeZXAFfg4RZ2nVUh0,1042
|
66
|
+
optimum/rbln/transformers/models/llama/llama_architecture.py,sha256=j4mifSOaIk7wwV9fL9wQSt5kR3rpnvjtxd3VzhMNdgY,1123
|
67
|
+
optimum/rbln/transformers/models/llama/modeling_llama.py,sha256=D9efkBVUr7TaOkAkiN_qrtQC0AyzLK7cb7UbZpo4XwI,1924
|
68
|
+
optimum/rbln/transformers/models/llava_next/__init__.py,sha256=3vi2rmTeKBydGRFOtxELhxWixZggFMpGex6xqfMgi-I,1064
|
69
|
+
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=Rkpso3eQ1tHXpfLdRUayut4X3J9zsXjF2in4UVN1Yhs,25883
|
70
|
+
optimum/rbln/transformers/models/midm/__init__.py,sha256=_6kYchy47frGMZ8uoUspZ9IwrmCBQJ-8kVfXM7xOMew,1249
|
71
|
+
optimum/rbln/transformers/models/midm/midm_architecture.py,sha256=PwGwHaYCHmJ4DRvlieVK_IzsZgG79_n-Y8kQs5NwT0A,5790
|
72
|
+
optimum/rbln/transformers/models/midm/modeling_midm.py,sha256=NB1Ie2GN9Ilisd9CMIbENRiTtUV0pK3eTStbuz0yQBg,2129
|
73
|
+
optimum/rbln/transformers/models/mistral/__init__.py,sha256=XtuOmzBITjj-H1yctXobJjHF908x1Wlxr_p4hi06v8I,1046
|
74
|
+
optimum/rbln/transformers/models/mistral/mistral_architecture.py,sha256=LCvY4L0Wq1VruKhZ3JTSiuZJqQRJlTae5A2bKsUBGAg,1128
|
75
|
+
optimum/rbln/transformers/models/mistral/modeling_mistral.py,sha256=i3X3HKGNee2ocEmpxdHMxuq7UAOgUs-QWlq2OizqA4g,1954
|
76
|
+
optimum/rbln/transformers/models/phi/__init__.py,sha256=LrGFTUo1oQnsPSTlxJqAJVVNUdUwq4u_Bf60RUgjLz4,1038
|
77
|
+
optimum/rbln/transformers/models/phi/modeling_phi.py,sha256=JfpuUB6cign-lqcUoprgq3gbQclZFT9HGV-NYVkSads,1910
|
78
|
+
optimum/rbln/transformers/models/phi/phi_architecture.py,sha256=GTmqWF6cJn4rhFTkuYxuzEcVTE-fM4dfpV0Ve1Abi9Q,4440
|
79
|
+
optimum/rbln/transformers/models/qwen2/__init__.py,sha256=1PLl1rlF14C6eSk3EZaDfyEHPaC4DZ2vwVlrklTkOYg,1042
|
80
|
+
optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=8ldxWKk85snFX_EViA7kgcgKAZ_QSbmQxhlO4yFvhOA,1924
|
81
|
+
optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=-X9OZ4HUCYDtwKnvidkWzCMPh_Xuu1wj-wRXIsQ9Pjg,1115
|
82
|
+
optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=Oa11lBWDNQWahqvDco3JIsZldYS-lO8qjpnaGKSfR00,1045
|
83
|
+
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=GSy_9bVdiZ4kavavsX-UJ62RIQn18_7k8i8PN6G2P9E,16131
|
84
|
+
optimum/rbln/transformers/models/t5/__init__.py,sha256=H1ns7mquQDkImSI1KT4oTe4owK4s_n28YqVxmZ31TF0,1133
|
85
|
+
optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=jBZDKwWXbuKEHUraU0N7P-XUPxznpCNdaRYs4buJY0Y,7776
|
86
|
+
optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=k3ROGNSGGuF1gFNV-LxoFFgfxo7ab5GSQA4GIi5MLsI,21074
|
87
|
+
optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=mz4cXqG9b0tDpTAw3qYn3FaJuolX601VmKBE3gohLSw,1043
|
88
|
+
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=vRDKoujGRvMvyAlJVJhj_EHf0OtQWbZHickFLzGjmDI,4231
|
89
|
+
optimum/rbln/transformers/models/whisper/__init__.py,sha256=PZ8qeAAFMas2MizwVYFxlpFWd5k1Pe1x-0IJfYAMhT8,1059
|
90
|
+
optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=Kwwskbp48wJxEkFGQLlm0L252rO7tx_YLYmOA-_IPwI,3387
|
91
|
+
optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=DX4tBJxULJY_UCm1Tw4EiGn1FjZioBjZZbFAB1Uipm4,15355
|
92
|
+
optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=OQzkGa2WSUn3OVQ1DYVOY49N46QvxO1hdEbQ7Ke-o_c,17203
|
93
|
+
optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=NTj4hCpd8L2_i5DZuV5wp-h8OlTLYVUqTrJxzY_Dg9g,1047
|
94
|
+
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=H5SEtmCAuG6pL1ovl4eLPGZ3tx1IPOilsxKvnbFDN-E,3821
|
95
|
+
optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
96
|
+
optimum/rbln/transformers/utils/rbln_quantization.py,sha256=-abKYe20hnwx1RtPE2Yz7C5slEKYmKohBSscoRoA2bo,7807
|
97
|
+
optimum/rbln/utils/__init__.py,sha256=F6hJP00eV1_hT_IVwqqYwLWcLQAvZbmmrNMJTia3mjI,1106
|
98
|
+
optimum/rbln/utils/decorator_utils.py,sha256=KDxCPC6G1Au8nokvTGjo--JyZbaWjOLzfJl_oewJ5oQ,2311
|
99
|
+
optimum/rbln/utils/hub.py,sha256=o-nA2I_jnB0S2AX0-q9lSpSNcdP_TeqZbHv84Gxxvi4,4592
|
100
|
+
optimum/rbln/utils/import_utils.py,sha256=fk8VIS46sB5zNqinfpmJLINjcJhTRSf-xdvp-g98Aps,4287
|
101
|
+
optimum/rbln/utils/logging.py,sha256=xIcLmUQoIJoBj3owkXN5_WQkQljcos6J6KSdX35IApw,2271
|
102
|
+
optimum/rbln/utils/model_utils.py,sha256=cnQbNtc2KUAJOcB6rHVwF8RpCNJFOTiCa91AQlUlgMM,1673
|
103
|
+
optimum/rbln/utils/runtime_utils.py,sha256=DXzRJKvLkiRYspefJsps5mHDpgQl_skA1BfIADsXPTg,3815
|
104
|
+
optimum/rbln/utils/save_utils.py,sha256=eFIPtmiblCJ3MvtxEPxmAR3iuLEUrzpyzwtVotDauhw,3283
|
105
|
+
optimum/rbln/utils/submodule.py,sha256=UHizJSL3osA5Jiaarjbvl7AUWlXp4p8Pb_9JZKsaoCI,3472
|
106
|
+
optimum/rbln/utils/timer_utils.py,sha256=o6EI-7-pcr3LhvCGJ1HIs1KH17yF2CaNpTsbHHbHmzc,1229
|
107
|
+
optimum_rbln-0.1.15.dist-info/METADATA,sha256=4Zxw1eSnrtAUDNrEEkOhMal7Ryh2CQ4niAKs-9I-dbc,4248
|
108
|
+
optimum_rbln-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
109
|
+
optimum_rbln-0.1.15.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
110
|
+
optimum_rbln-0.1.15.dist-info/RECORD,,
|
@@ -1,139 +0,0 @@
|
|
1
|
-
# Copyright 2024 Rebellions Inc.
|
2
|
-
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at:
|
6
|
-
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
|
-
from typing import List, Optional
|
25
|
-
|
26
|
-
import torch
|
27
|
-
from transformers import AutoTokenizer, TextIteratorStreamer
|
28
|
-
|
29
|
-
|
30
|
-
class BatchTextIteratorStreamer(TextIteratorStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g., in an interactive
    Gradio demo).

    This iterator extends TextIteratorStreamer to support batched text generation. Every `put` call decodes the new
    tokens of each sequence in the batch and enqueues ONE list holding a print-ready string per sequence. `end` flushes
    whatever text is still cached for each sequence, enqueues that final list, and then enqueues the stop signal.
    Iterating over the streamer therefore yields lists of `batch_size` strings.

    Parameters:
        batch_size (int):
            The number of sequences generated together; incoming tokens are split into this many rows.
        tokenizer (AutoTokenizer):
            The tokenizer used to decode the tokens.
        skip_prompt (bool, optional, default=False):
            Whether to skip the prompt to `.generate()` or not. Useful, for example, for chatbots.
        timeout (float, optional):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()` when it is called in a separate thread.
        **decode_kwargs (dict, optional):
            Additional keyword arguments to pass to the tokenizer's `decode` method.
    """

    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size: int = batch_size
        # Per-sequence buffer of token ids not yet fully emitted, and the number of
        # decoded characters of each buffer that have already been enqueued.
        # These replace the single-sequence `token_cache` / `print_len` of the parent.
        self.token_cache: List[List[int]] = [[] for _ in range(batch_size)]
        self.print_len: List[int] = [0] * batch_size
        # Flag set by `block()`, cleared by `end()`, and exposed via `is_blocked()`;
        # nothing else in this class acts on it.
        self.blocked = False

    def put(self, value):
        """
        Receive new token ids for the whole batch, decode them, and enqueue the text that is safe to show.

        `value` is reshaped to `(batch_size, value.shape[0] // batch_size)` when it has fewer than two
        dimensions, so row `i` always holds the new tokens of sequence `i`. If `skip_prompt` is set, the
        first call after construction (or after `end()`) is treated as the prompt and ignored.
        """
        if len(value.shape) < 2:
            value = torch.reshape(value, (self.batch_size, value.shape[0] // self.batch_size))

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        batch_printable_text = []
        for i in range(self.batch_size):
            # Add the new tokens to this sequence's cache and decode the whole cache, because
            # a token's text can depend on its neighbours.
            self.token_cache[i].extend(value[i].tolist())
            text = self.tokenizer.decode(self.token_cache[i], **self.decode_kwargs)

            if text.endswith("\n"):
                # A newline is a safe boundary: emit everything not yet shown and reset the cache.
                printable_text = text[self.print_len[i] :]
                self.token_cache[i] = []
                self.print_len[i] = 0
            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
                # CJK characters are emitted immediately since they do not rely on spaces between words.
                printable_text = text[self.print_len[i] :]
                self.print_len[i] += len(printable_text)
            else:
                # Emit only up to the last space so a partially decoded word is not shown and
                # then changed by the next token (a simple heuristic).
                printable_text = text[self.print_len[i] : text.rfind(" ") + 1]
                self.print_len[i] += len(printable_text)
            batch_printable_text.append(printable_text)

        self.on_finalized_text(batch_printable_text)

    def end(self):
        """
        Flush each sequence's remaining cached text, then enqueue the stop signal.

        Enqueues one final list with the not-yet-emitted text of every sequence ("" for sequences with an
        empty cache), followed by `stop_signal`. Also re-arms prompt skipping for the next generation and
        clears the `blocked` flag. Nothing is printed.
        """
        batch_printable_text = []
        for idx in range(self.batch_size):
            if len(self.token_cache[idx]) > 0:
                text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)
                printable_text = text[self.print_len[idx] :]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            else:
                printable_text = ""
            batch_printable_text.append(printable_text)

        self.next_tokens_are_prompt = True
        self.on_finalized_text(batch_printable_text, stream_end=True)
        self.blocked = False

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        """Put the per-sequence `texts` on the queue and, when `stream_end` is true, the stop signal after it."""
        self.text_queue.put(texts, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __next__(self):
        """Return the next list of per-sequence strings; raise StopIteration when the stop signal arrives."""
        value = self.text_queue.get(timeout=self.timeout)
        # The queue hands back the very object `on_finalized_text` put in, so identity is the exact
        # sentinel test and avoids running `==` against list payloads.
        if value is self.stop_signal:
            raise StopIteration()
        return value

    def block(self):
        """Set the `blocked` flag (cleared again by `end()`)."""
        self.blocked = True

    def is_blocked(self):
        """Return the current value of the `blocked` flag."""
        return self.blocked
|