optimum-rbln 0.2.1a2__py3-none-any.whl → 0.2.1a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/models/autoencoders/vae.py +2 -2
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -2
- optimum/rbln/modeling.py +2 -2
- optimum/rbln/modeling_base.py +35 -15
- optimum/rbln/transformers/models/bert/modeling_bert.py +2 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +2 -2
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +2 -2
- optimum/rbln/transformers/models/whisper/generation_whisper.py +19 -17
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +2 -2
- optimum/rbln/utils/save_utils.py +3 -2
- {optimum_rbln-0.2.1a2.dist-info → optimum_rbln-0.2.1a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.2.1a2.dist-info → optimum_rbln-0.2.1a3.dist-info}/RECORD +24 -24
- {optimum_rbln-0.2.1a2.dist-info → optimum_rbln-0.2.1a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.2.1a2.dist-info → optimum_rbln-0.2.1a3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
17
16
|
|
18
17
|
import rebel
|
@@ -23,6 +22,7 @@ from transformers import PretrainedConfig
|
|
23
22
|
|
24
23
|
from ....modeling import RBLNModel
|
25
24
|
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
25
|
+
from ....utils.logging import get_logger
|
26
26
|
from ...modeling_diffusers import RBLNDiffusionMixin
|
27
27
|
from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
|
28
28
|
|
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
|
31
31
|
import torch
|
32
32
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
|
33
33
|
|
34
|
-
logger =
|
34
|
+
logger = get_logger(__name__)
|
35
35
|
|
36
36
|
|
37
37
|
class RBLNAutoencoderKL(RBLNModel):
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from typing import TYPE_CHECKING
|
17
16
|
|
18
17
|
import torch # noqa: I001
|
@@ -20,13 +19,14 @@ from diffusers import AutoencoderKL
|
|
20
19
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
21
20
|
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
22
21
|
|
22
|
+
from ....utils.logging import get_logger
|
23
23
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
24
24
|
|
25
25
|
|
26
26
|
if TYPE_CHECKING:
|
27
27
|
import torch
|
28
28
|
|
29
|
-
logger =
|
29
|
+
logger = get_logger(__name__)
|
30
30
|
|
31
31
|
|
32
32
|
class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import importlib
|
16
|
-
import logging
|
17
16
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
18
17
|
|
19
18
|
import torch
|
@@ -22,6 +21,7 @@ from transformers import PretrainedConfig
|
|
22
21
|
|
23
22
|
from ...modeling import RBLNModel
|
24
23
|
from ...modeling_config import RBLNCompileConfig, RBLNConfig
|
24
|
+
from ...utils.logging import get_logger
|
25
25
|
from ..modeling_diffusers import RBLNDiffusionMixin
|
26
26
|
|
27
27
|
|
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
|
29
29
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
30
30
|
|
31
31
|
|
32
|
-
logger =
|
32
|
+
logger = get_logger(__name__)
|
33
33
|
|
34
34
|
|
35
35
|
class _ControlNetModel(torch.nn.Module):
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
17
16
|
|
18
17
|
import torch
|
@@ -22,13 +21,14 @@ from transformers import PretrainedConfig
|
|
22
21
|
|
23
22
|
from ....modeling import RBLNModel
|
24
23
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
24
|
+
from ....utils.logging import get_logger
|
25
25
|
from ...modeling_diffusers import RBLNDiffusionMixin
|
26
26
|
|
27
27
|
|
28
28
|
if TYPE_CHECKING:
|
29
29
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
30
30
|
|
31
|
-
logger =
|
31
|
+
logger = get_logger(__name__)
|
32
32
|
|
33
33
|
|
34
34
|
class SD3Transformer2DModelWrapper(torch.nn.Module):
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from dataclasses import dataclass
|
17
16
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
18
17
|
|
@@ -22,13 +21,14 @@ from transformers import PretrainedConfig
|
|
22
21
|
|
23
22
|
from ....modeling import RBLNModel
|
24
23
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
24
|
+
from ....utils.logging import get_logger
|
25
25
|
from ...modeling_diffusers import RBLNDiffusionMixin
|
26
26
|
|
27
27
|
|
28
28
|
if TYPE_CHECKING:
|
29
29
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
30
30
|
|
31
|
-
logger =
|
31
|
+
logger = get_logger(__name__)
|
32
32
|
|
33
33
|
|
34
34
|
class _UNet_SD(torch.nn.Module):
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
import os
|
17
16
|
from pathlib import Path
|
18
17
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
@@ -21,13 +20,14 @@ import torch
|
|
21
20
|
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
22
21
|
|
23
22
|
from ....modeling import RBLNModel
|
23
|
+
from ....utils.logging import get_logger
|
24
24
|
from ...models.controlnet import RBLNControlNetModel
|
25
25
|
|
26
26
|
|
27
27
|
if TYPE_CHECKING:
|
28
28
|
pass
|
29
29
|
|
30
|
-
logger =
|
30
|
+
logger = get_logger(__name__)
|
31
31
|
|
32
32
|
|
33
33
|
class RBLNMultiControlNetModel(RBLNModel):
|
@@ -34,16 +34,17 @@ from diffusers import StableDiffusionControlNetPipeline
|
|
34
34
|
from diffusers.image_processor import PipelineImageInput
|
35
35
|
from diffusers.pipelines.controlnet.pipeline_controlnet import retrieve_timesteps
|
36
36
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
37
|
-
from diffusers.utils import deprecate
|
37
|
+
from diffusers.utils import deprecate
|
38
38
|
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
39
39
|
|
40
40
|
from ....utils.decorator_utils import remove_compile_time_kwargs
|
41
|
+
from ....utils.logging import get_logger
|
41
42
|
from ...modeling_diffusers import RBLNDiffusionMixin
|
42
43
|
from ...models import RBLNControlNetModel
|
43
44
|
from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
|
44
45
|
|
45
46
|
|
46
|
-
logger =
|
47
|
+
logger = get_logger(__name__)
|
47
48
|
|
48
49
|
|
49
50
|
class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
|
optimum/rbln/modeling.py
CHANGED
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from pathlib import Path
|
17
16
|
from tempfile import TemporaryDirectory
|
18
17
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
@@ -24,13 +23,14 @@ from transformers import AutoConfig, PretrainedConfig
|
|
24
23
|
|
25
24
|
from .modeling_base import RBLNBaseModel
|
26
25
|
from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, use_rbln_config
|
26
|
+
from .utils.logging import get_logger
|
27
27
|
|
28
28
|
|
29
29
|
if TYPE_CHECKING:
|
30
30
|
from transformers import PreTrainedModel
|
31
31
|
|
32
32
|
|
33
|
-
logger =
|
33
|
+
logger = get_logger(__name__)
|
34
34
|
|
35
35
|
|
36
36
|
class RBLNModel(RBLNBaseModel):
|
optimum/rbln/modeling_base.py
CHANGED
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import importlib
|
16
|
-
import logging
|
17
16
|
import os
|
18
17
|
import shutil
|
19
18
|
from abc import ABC, abstractmethod
|
@@ -32,6 +31,7 @@ from transformers import (
|
|
32
31
|
|
33
32
|
from .modeling_config import RBLNCompileConfig, RBLNConfig, use_rbln_config
|
34
33
|
from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
|
34
|
+
from .utils.logging import get_logger
|
35
35
|
from .utils.runtime_utils import UnavailableRuntime
|
36
36
|
from .utils.save_utils import maybe_load_preprocessors
|
37
37
|
from .utils.submodule import SubModulesMixin
|
@@ -40,7 +40,7 @@ from .utils.submodule import SubModulesMixin
|
|
40
40
|
if TYPE_CHECKING:
|
41
41
|
from transformers import PreTrainedModel
|
42
42
|
|
43
|
-
logger =
|
43
|
+
logger = get_logger(__name__)
|
44
44
|
|
45
45
|
|
46
46
|
class PreTrainedModel(ABC): # noqa: F811
|
@@ -442,27 +442,47 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
442
442
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
443
443
|
return
|
444
444
|
|
445
|
-
os.makedirs(save_directory, exist_ok=True)
|
446
|
-
|
447
445
|
real_save_dir = self.model_save_dir / self.subfolder
|
448
446
|
save_directory_path = Path(save_directory)
|
449
|
-
|
450
|
-
|
451
|
-
raise FileExistsError(
|
452
|
-
f"Cannot save model to '{save_directory}'. "
|
453
|
-
f"This directory already exists and contains the model files."
|
454
|
-
)
|
455
|
-
shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
|
456
|
-
self.config.save_pretrained(save_directory)
|
457
|
-
if self.generation_config is not None:
|
458
|
-
self.generation_config.save_pretrained(save_directory)
|
459
|
-
else:
|
447
|
+
|
448
|
+
if not os.path.exists(real_save_dir) or not os.path.isdir(real_save_dir):
|
460
449
|
raise FileNotFoundError(
|
461
450
|
f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
|
462
451
|
f"Cannot save to the specified destination '{save_directory}'. "
|
463
452
|
f"Please ensure the model directory exists and you have the necessary permissions to access it."
|
464
453
|
)
|
465
454
|
|
455
|
+
if save_directory_path.absolute() == real_save_dir.absolute():
|
456
|
+
raise FileExistsError(
|
457
|
+
f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
|
458
|
+
)
|
459
|
+
|
460
|
+
# Create a temporary directory next to the target directory
|
461
|
+
tmp_dir = save_directory + ".tmp"
|
462
|
+
try:
|
463
|
+
# Remove temporary directory if it exists from a previous failed attempt
|
464
|
+
if os.path.exists(tmp_dir):
|
465
|
+
shutil.rmtree(tmp_dir)
|
466
|
+
|
467
|
+
# First copy everything to a temporary directory
|
468
|
+
shutil.copytree(real_save_dir, tmp_dir)
|
469
|
+
|
470
|
+
# Save configs to the temporary directory
|
471
|
+
self.config.save_pretrained(tmp_dir)
|
472
|
+
if self.generation_config is not None:
|
473
|
+
self.generation_config.save_pretrained(tmp_dir)
|
474
|
+
|
475
|
+
# If everything succeeded, atomically replace the target directory
|
476
|
+
if os.path.exists(save_directory):
|
477
|
+
shutil.rmtree(save_directory)
|
478
|
+
os.rename(tmp_dir, save_directory)
|
479
|
+
|
480
|
+
except Exception as e:
|
481
|
+
# Clean up the temporary directory if anything fails
|
482
|
+
if os.path.exists(tmp_dir):
|
483
|
+
shutil.rmtree(tmp_dir)
|
484
|
+
raise e # Re-raise the exception after cleanup
|
485
|
+
|
466
486
|
if push_to_hub:
|
467
487
|
return super().push_to_hub(save_directory, **kwargs)
|
468
488
|
|
@@ -13,17 +13,17 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
import logging
|
17
16
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
18
17
|
|
19
18
|
from transformers import PretrainedConfig
|
20
19
|
|
21
20
|
from ....modeling import RBLNModel
|
22
21
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
22
|
+
from ....utils.logging import get_logger
|
23
23
|
from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForQuestionAnswering
|
24
24
|
|
25
25
|
|
26
|
-
logger =
|
26
|
+
logger = get_logger(__name__)
|
27
27
|
|
28
28
|
if TYPE_CHECKING:
|
29
29
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
17
16
|
|
18
17
|
import torch
|
@@ -28,9 +27,10 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
|
28
27
|
from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
|
29
28
|
from ....modeling import RBLNModel
|
30
29
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
30
|
+
from ....utils.logging import get_logger
|
31
31
|
|
32
32
|
|
33
|
-
logger =
|
33
|
+
logger = get_logger(__name__)
|
34
34
|
|
35
35
|
if TYPE_CHECKING:
|
36
36
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
|
17
16
|
|
18
17
|
from transformers import AutoModelForDepthEstimation
|
@@ -20,9 +19,10 @@ from transformers.modeling_outputs import DepthEstimatorOutput
|
|
20
19
|
|
21
20
|
from ....modeling import RBLNModel
|
22
21
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
22
|
+
from ....utils.logging import get_logger
|
23
23
|
|
24
24
|
|
25
|
-
logger =
|
25
|
+
logger = get_logger(__name__)
|
26
26
|
|
27
27
|
if TYPE_CHECKING:
|
28
28
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
import logging
|
17
16
|
from pathlib import Path
|
18
17
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
19
18
|
|
@@ -30,10 +29,11 @@ from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLM
|
|
30
29
|
|
31
30
|
from ....modeling import RBLNModel
|
32
31
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
32
|
+
from ....utils.logging import get_logger
|
33
33
|
from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
|
34
34
|
|
35
35
|
|
36
|
-
logger =
|
36
|
+
logger = get_logger(__name__)
|
37
37
|
|
38
38
|
if TYPE_CHECKING:
|
39
39
|
from transformers import (
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
import logging
|
17
16
|
from abc import ABC
|
18
17
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
19
18
|
|
@@ -25,10 +24,11 @@ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
|
25
24
|
|
26
25
|
from ....modeling import RBLNModel
|
27
26
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
27
|
+
from ....utils.logging import get_logger
|
28
28
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
29
29
|
|
30
30
|
|
31
|
-
logger =
|
31
|
+
logger = get_logger(__name__)
|
32
32
|
|
33
33
|
if TYPE_CHECKING:
|
34
34
|
from transformers import (
|
@@ -12,7 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import logging
|
16
15
|
from typing import TYPE_CHECKING, Any, Dict, Union
|
17
16
|
|
18
17
|
import torch
|
@@ -21,9 +20,10 @@ from transformers.modeling_outputs import CausalLMOutput
|
|
21
20
|
|
22
21
|
from ....modeling import RBLNModel
|
23
22
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
23
|
+
from ....utils.logging import get_logger
|
24
24
|
|
25
25
|
|
26
|
-
logger =
|
26
|
+
logger = get_logger(__name__)
|
27
27
|
|
28
28
|
if TYPE_CHECKING:
|
29
29
|
from transformers import (
|
@@ -32,6 +32,8 @@ Modified from `transformers.models.whisper.generation_whisper.py`
|
|
32
32
|
"""
|
33
33
|
|
34
34
|
import torch
|
35
|
+
import transformers
|
36
|
+
from packaging import version
|
35
37
|
from transformers import GenerationMixin
|
36
38
|
from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
|
37
39
|
|
@@ -47,17 +49,12 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
|
|
47
49
|
self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, *args, **kwargs
|
48
50
|
):
|
49
51
|
# remove all previously passed decoder input ids
|
50
|
-
|
51
|
-
|
52
|
-
# 4.40.2 has no keyword shortform, it has seperate codes from generation_fallback
|
53
|
-
is_shortform = kwargs.get("is_shortform", False)
|
54
|
-
start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
|
52
|
+
# should happen only if it is the first generated segment
|
53
|
+
start_idx = decoder_input_ids.shape[-1]
|
55
54
|
|
56
55
|
if isinstance(seek_outputs, torch.Tensor):
|
57
|
-
|
58
|
-
return seek_outputs, seek_outputs
|
56
|
+
return seek_outputs[:, start_idx:], seek_outputs
|
59
57
|
|
60
|
-
############## rbln validation#############
|
61
58
|
if return_token_timestamps and not self.rbln_token_timestamps:
|
62
59
|
raise RuntimeError(
|
63
60
|
"To use .generate() with return_token_timestamps=True, the model must be compiled with rbln_token_timestamps=True. "
|
@@ -67,11 +64,19 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
|
|
67
64
|
|
68
65
|
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
69
66
|
num_frames = getattr(generation_config, "num_frames", None)
|
70
|
-
|
71
|
-
seek_outputs
|
72
|
-
|
73
|
-
|
74
|
-
|
67
|
+
if version.parse(transformers.__version__) >= version.parse("4.46.0"):
|
68
|
+
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
69
|
+
seek_outputs,
|
70
|
+
generation_config.alignment_heads,
|
71
|
+
num_frames=num_frames,
|
72
|
+
num_input_ids=decoder_input_ids.shape[-1],
|
73
|
+
)
|
74
|
+
else:
|
75
|
+
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
76
|
+
seek_outputs,
|
77
|
+
generation_config.alignment_heads,
|
78
|
+
num_frames=num_frames,
|
79
|
+
)
|
75
80
|
seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
|
76
81
|
|
77
82
|
def split_by_batch_index(values, key, batch_idx):
|
@@ -87,15 +92,12 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
|
|
87
92
|
|
88
93
|
sequence_tokens = seek_outputs["sequences"]
|
89
94
|
|
90
|
-
##################################### thkim change #############################################
|
91
95
|
valid_seekoutputs = []
|
92
96
|
for k, v in seek_outputs.items():
|
93
97
|
if v is not None and len(v) > 0 and v[0] is not None:
|
94
98
|
valid_seekoutputs.append((k, v))
|
95
99
|
seek_outputs = [
|
96
|
-
{k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs}
|
97
|
-
# {k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
|
98
|
-
for i in range(sequence_tokens.shape[0])
|
100
|
+
{k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs} for i in range(sequence_tokens.shape[0])
|
99
101
|
]
|
100
102
|
|
101
103
|
return sequence_tokens, seek_outputs
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
import logging
|
17
16
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
18
17
|
|
19
18
|
import rebel
|
@@ -30,12 +29,13 @@ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
|
30
29
|
|
31
30
|
from ....modeling import RBLNModel
|
32
31
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
32
|
+
from ....utils.logging import get_logger
|
33
33
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
34
34
|
from .generation_whisper import RBLNWhisperGenerationMixin
|
35
35
|
from .whisper_architecture import WhisperWrapper
|
36
36
|
|
37
37
|
|
38
|
-
logger =
|
38
|
+
logger = get_logger(__name__)
|
39
39
|
|
40
40
|
if TYPE_CHECKING:
|
41
41
|
from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
|
@@ -13,7 +13,6 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
import logging
|
17
16
|
from typing import TYPE_CHECKING, Optional, Union
|
18
17
|
|
19
18
|
import torch
|
@@ -21,9 +20,10 @@ from transformers import PretrainedConfig
|
|
21
20
|
|
22
21
|
from ....modeling import RBLNModel
|
23
22
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
23
|
+
from ....utils.logging import get_logger
|
24
24
|
|
25
25
|
|
26
|
-
logger =
|
26
|
+
logger = get_logger(__name__)
|
27
27
|
|
28
28
|
if TYPE_CHECKING:
|
29
29
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
optimum/rbln/utils/save_utils.py
CHANGED
@@ -30,14 +30,15 @@
|
|
30
30
|
Refer to huggingface/optimum/blob/4fdeea77d71e79451ba53e0c1f9d8f37e9704268/optimum/utils/save_utils.py
|
31
31
|
"""
|
32
32
|
|
33
|
-
import logging
|
34
33
|
from pathlib import Path
|
35
34
|
from typing import List, Union
|
36
35
|
|
37
36
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
38
37
|
|
38
|
+
from .logging import get_logger
|
39
39
|
|
40
|
-
|
40
|
+
|
41
|
+
logger = get_logger(__name__)
|
41
42
|
|
42
43
|
|
43
44
|
def maybe_load_preprocessors(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: optimum-rbln
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.1a3
|
4
4
|
Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
|
5
5
|
Project-URL: Homepage, https://rebellions.ai
|
6
6
|
Project-URL: Documentation, https://docs.rbln.ai
|
@@ -1,23 +1,23 @@
|
|
1
1
|
optimum/rbln/__init__.py,sha256=sLCjJu_MLZEKDOwHIlJP4u4GzGZx-1kqHTYGw5B4xDg,6096
|
2
|
-
optimum/rbln/__version__.py,sha256=
|
3
|
-
optimum/rbln/modeling.py,sha256=
|
4
|
-
optimum/rbln/modeling_base.py,sha256=
|
2
|
+
optimum/rbln/__version__.py,sha256=Qa8tLTuiehljsgp_ibSY6aee43cZYh5J_fQ5zMTZ6SA,413
|
3
|
+
optimum/rbln/modeling.py,sha256=REImAAKO82CqSNABR-9E1jJEsWch9amSOwOOQhFEYLY,8283
|
4
|
+
optimum/rbln/modeling_base.py,sha256=_5M8hVySDwCJ6qfeku2_nJAPu_5JLfAUu3HO1bc3ALM,21098
|
5
5
|
optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
|
6
6
|
optimum/rbln/diffusers/__init__.py,sha256=68FTAMpbbMflm8qiSqfM5J2_gFb3iU3fng6AL0TG47A,2913
|
7
7
|
optimum/rbln/diffusers/modeling_diffusers.py,sha256=E1x-iOKEJCUB6ml0RgtFEVPPk6J6pqEF-JTEyOZzOyc,14928
|
8
8
|
optimum/rbln/diffusers/models/__init__.py,sha256=aSL5_yd-y8Q6DxNvfQ-yl-BUNyMzI1P6AikjQMKZzpI,1357
|
9
|
-
optimum/rbln/diffusers/models/controlnet.py,sha256=
|
9
|
+
optimum/rbln/diffusers/models/controlnet.py,sha256=EM_HlzCdaZdnnK0oGpY2fQeigPqHhlwh4NHCzlmoumI,10512
|
10
10
|
optimum/rbln/diffusers/models/autoencoders/__init__.py,sha256=nMfnwEwuOje-qKofAw-uOWUWcYV_YvnaN68IGfDdqHg,645
|
11
|
-
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py,sha256=
|
12
|
-
optimum/rbln/diffusers/models/autoencoders/vae.py,sha256=
|
11
|
+
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py,sha256=rCbC32bJnfXtsLdVvNVVHpRAkCYy6jeCSwIZ-JSReWk,9220
|
12
|
+
optimum/rbln/diffusers/models/autoencoders/vae.py,sha256=A-F2TRJ2vL4gNXiMT_hRGeanIFKWxJ1QaKmYVp41rwI,2513
|
13
13
|
optimum/rbln/diffusers/models/transformers/__init__.py,sha256=TEhARgQJx_NUZzI6M8gt3aWbdzmLHnM6FMSQd9M9zCk,654
|
14
|
-
optimum/rbln/diffusers/models/transformers/transformer_sd3.py,sha256=
|
14
|
+
optimum/rbln/diffusers/models/transformers/transformer_sd3.py,sha256=n_krmMgiRxWrG--567PNpk58EG_X7x7H4gidIkRvwjo,7308
|
15
15
|
optimum/rbln/diffusers/models/unets/__init__.py,sha256=MaICuK9CWjgzejXy8y2NDrphuEq1rkzanF8u45k6O5I,655
|
16
|
-
optimum/rbln/diffusers/models/unets/unet_2d_condition.py,sha256=
|
16
|
+
optimum/rbln/diffusers/models/unets/unet_2d_condition.py,sha256=Z0-eAZw1Gah24y6uOO5m9-GRruBppCSdV2NQZLNtBaI,14021
|
17
17
|
optimum/rbln/diffusers/pipelines/__init__.py,sha256=i8AQJSoV9clLTill7wP5ECci6E7lC2gBaNuqfhYklZk,2469
|
18
18
|
optimum/rbln/diffusers/pipelines/controlnet/__init__.py,sha256=n1Ef22TSeax-kENi_d8K6wGGHSNEo9QkUeygELHgcao,983
|
19
|
-
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py,sha256=
|
20
|
-
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py,sha256=
|
19
|
+
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py,sha256=JWKtnZYBIfgmbAo0SLFIvHBQCv2BPSFNvpcdjG4GUOY,4113
|
20
|
+
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py,sha256=dGdw5cwJLS4CLv6IHskk5ZCcPgS7UDuHKbfOZ8ojNUs,35187
|
21
21
|
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py,sha256=7xCiXrH4ToCTHohVGFXqO7_f9G8HShYaHgZxoMZARkQ,33664
|
22
22
|
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py,sha256=Gzt2wg4dgFg0TV3Bu0cs8Xru3wVrxWUxxgciwZ-QKLE,44755
|
23
23
|
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py,sha256=RfwxNX_zQWFtvvFQJ5bt3qtHbdYdQV_3XLHm9WYCKOs,46084
|
@@ -49,14 +49,14 @@ optimum/rbln/transformers/models/bart/__init__.py,sha256=32HPe0_GIO0hp9U464Iv6Jd
|
|
49
49
|
optimum/rbln/transformers/models/bart/bart_architecture.py,sha256=dTkgMpNkyh4vT_mZU5tQ5bvH_lRZfRjaJ1gIHvJkmgs,5479
|
50
50
|
optimum/rbln/transformers/models/bart/modeling_bart.py,sha256=ADRbE-5N3xJ60AzzjJ4BZs_THmB71qs4XTr9iFqsEqE,5667
|
51
51
|
optimum/rbln/transformers/models/bert/__init__.py,sha256=YVV7k_laU6yJBawZrgjIWjRmIF-Y4oQQHqyf8lsraQs,691
|
52
|
-
optimum/rbln/transformers/models/bert/modeling_bert.py,sha256
|
52
|
+
optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=-nv-sgmHkyHQIoQvF8-lXOJiL4eaa1pq8MpdN4uRi9M,4668
|
53
53
|
optimum/rbln/transformers/models/clip/__init__.py,sha256=ssJqlEt318ti2QaEakGh_tO3Ap1VSPCVF-ymUuvjAJs,698
|
54
|
-
optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=
|
54
|
+
optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=E1QfVNq1sTCp7uvuha1ZPfXMwvMTkGV9L4oFdmy1w4g,5724
|
55
55
|
optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
|
56
56
|
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=BjQHwoPZfM-KUQzxm4AU-PdmoMgLxnCG6kfSpGjUvrk,36578
|
57
57
|
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=mAgRRMGVHvTUjJBDlmUOjNhSNjprKSD7tLeFknrx0Rw,25810
|
58
58
|
optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
|
59
|
-
optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=
|
59
|
+
optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
|
60
60
|
optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
|
61
61
|
optimum/rbln/transformers/models/exaone/exaone_architecture.py,sha256=thzWLVz3eUcst4IPiOavta5QeXZw7JQwwfdIzQ_x6Ns,3029
|
62
62
|
optimum/rbln/transformers/models/exaone/modeling_exaone.py,sha256=WjyH8PmsMljSea7kJn_Cq1FJ96OXwXAoU7hv2Q8zUnI,1747
|
@@ -70,7 +70,7 @@ optimum/rbln/transformers/models/llama/__init__.py,sha256=jo_j_eIrHYGNEhR5lb6g3r
|
|
70
70
|
optimum/rbln/transformers/models/llama/llama_architecture.py,sha256=S7MCPfyjG5eUqgaS-QNBB0ApUD6wnb5fR0RHq7k7-pA,728
|
71
71
|
optimum/rbln/transformers/models/llama/modeling_llama.py,sha256=Z3iony7icoFhRQ11MAuFx9UF03uJCsvJQZ6bxHXlrgk,1530
|
72
72
|
optimum/rbln/transformers/models/llava_next/__init__.py,sha256=VLieyWm-UgvuNxw9B38wrL1Jsa09NBDX_ebABmdpTbs,670
|
73
|
-
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=
|
73
|
+
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=_8zKsI-Kj4bbsPLnERJqg-0oC6EyAWrmnxvszsAtRaA,26398
|
74
74
|
optimum/rbln/transformers/models/midm/__init__.py,sha256=UJSaErsF-z6dZERIS143WTaygffZyzEGqoQ2ZPDiM-c,855
|
75
75
|
optimum/rbln/transformers/models/midm/midm_architecture.py,sha256=mueRmMGX6UplZb0C0RFdUOa9lsNH8YJHV6rYrDLOdlQ,5302
|
76
76
|
optimum/rbln/transformers/models/midm/modeling_midm.py,sha256=GG25BozEZriAL-OPFGpzOjyDtSFB-NfeiLJTDAqxe20,1734
|
@@ -84,19 +84,19 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
|
|
84
84
|
optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
|
85
85
|
optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
|
86
86
|
optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
|
87
|
-
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=
|
87
|
+
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=2hkCPvaiyS16zdtUiJKhvpk1qJfsXVLrAQPgAtixCg0,15426
|
88
88
|
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=15yoF-wyhcLcK-Z2MOUmyPlkOMNTVOJ013uBepqtpxA,18387
|
89
89
|
optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
|
90
90
|
optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=MFs-3yYviV1QqSpsTB2GarTEs9wGH5AYofksLQLMBXg,8043
|
91
91
|
optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=kkjErS42mW2jv5O_xL7BaKobvvqy7BGmYOowKyHakvI,7189
|
92
92
|
optimum/rbln/transformers/models/wav2vec2/__init__.py,sha256=YpgA0K-vyg9veh0eL_jxauosbRpb_kpGKHvvQLBspKM,649
|
93
|
-
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=
|
93
|
+
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py,sha256=JYJmV52j6cBwim4RanVJryfKnV80V96ol0A-oR6o7cg,3856
|
94
94
|
optimum/rbln/transformers/models/whisper/__init__.py,sha256=ktnNe5ri3ycCWZ_W_voFB9y9-vgGgxS1X9s8LBRZmWc,665
|
95
|
-
optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=
|
96
|
-
optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=
|
95
|
+
optimum/rbln/transformers/models/whisper/generation_whisper.py,sha256=GIHTca3b1VtW81kp7BzKQ7f77c2t9OsEsbZetripgDo,4582
|
96
|
+
optimum/rbln/transformers/models/whisper/modeling_whisper.py,sha256=0nBADNxE0A1ozBbRutTBvxpo_Y1qkOycT_zronkN-ZU,15840
|
97
97
|
optimum/rbln/transformers/models/whisper/whisper_architecture.py,sha256=eP3UgkwCRaaFjc5Jc4ZEiWxr3-L7oJx9KzpJ7eFkwUs,13158
|
98
98
|
optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=fC7iNcdxBZ_6eOF2snStmf8r2M3c8O_-XcXnQEaHQCE,653
|
99
|
-
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=
|
99
|
+
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=lKSeL3RUwIyfuca2jZ6SFV4N59EJS4UD59JMUfh3BiA,4767
|
100
100
|
optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
101
101
|
optimum/rbln/transformers/utils/rbln_quantization.py,sha256=gwBVHf97sQgPNmGa0wq87E8mPyrtXYhMnO4X4sKp3c8,7639
|
102
102
|
optimum/rbln/utils/__init__.py,sha256=ieDBT2VFTt2E0M4v_POLBpuGW9LxSydpb_DuPd6PQqc,712
|
@@ -106,9 +106,9 @@ optimum/rbln/utils/import_utils.py,sha256=ec-tISKIjUPHIfjzj6p-w78NVejWVBohb59f7J
|
|
106
106
|
optimum/rbln/utils/logging.py,sha256=VKKBmlQSdg6iZCGmAXaWYiW67K84jyp1QJhLQSSjPPE,3453
|
107
107
|
optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvVxDvs,1278
|
108
108
|
optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
|
109
|
-
optimum/rbln/utils/save_utils.py,sha256=
|
109
|
+
optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
|
110
110
|
optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
|
111
|
-
optimum_rbln-0.2.
|
112
|
-
optimum_rbln-0.2.
|
113
|
-
optimum_rbln-0.2.
|
114
|
-
optimum_rbln-0.2.
|
111
|
+
optimum_rbln-0.2.1a3.dist-info/METADATA,sha256=umGg7JkKhTcNc5AOyzubqzpoPXnGY1WosDi48dfAROw,5300
|
112
|
+
optimum_rbln-0.2.1a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
113
|
+
optimum_rbln-0.2.1a3.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
114
|
+
optimum_rbln-0.2.1a3.dist-info/RECORD,,
|
File without changes
|
File without changes
|