qwen-tts 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qwen_tts/core/models/modeling_qwen3_tts.py +63 -17
- qwen_tts-0.0.4.dist-info/METADATA +1384 -0
- {qwen_tts-0.0.3.dist-info → qwen_tts-0.0.4.dist-info}/RECORD +7 -7
- qwen_tts-0.0.3.dist-info/METADATA +0 -29
- {qwen_tts-0.0.3.dist-info → qwen_tts-0.0.4.dist-info}/WHEEL +0 -0
- {qwen_tts-0.0.3.dist-info → qwen_tts-0.0.4.dist-info}/entry_points.txt +0 -0
- {qwen_tts-0.0.3.dist-info → qwen_tts-0.0.4.dist-info}/licenses/LICENSE +0 -0
- {qwen_tts-0.0.3.dist-info → qwen_tts-0.0.4.dist-info}/top_level.txt +0 -0
|
@@ -19,7 +19,9 @@ import os
|
|
|
19
19
|
from dataclasses import dataclass
|
|
20
20
|
from typing import Callable, Optional
|
|
21
21
|
|
|
22
|
+
import huggingface_hub
|
|
22
23
|
import torch
|
|
24
|
+
from huggingface_hub import snapshot_download
|
|
23
25
|
from librosa.filters import mel as librosa_mel_fn
|
|
24
26
|
from torch import nn
|
|
25
27
|
from torch.nn import functional as F
|
|
@@ -27,34 +29,69 @@ from transformers.activations import ACT2FN
|
|
|
27
29
|
from transformers.cache_utils import Cache, DynamicCache
|
|
28
30
|
from transformers.generation import GenerationMixin
|
|
29
31
|
from transformers.integrations import use_kernel_forward_from_hub
|
|
30
|
-
from transformers.masking_utils import (
|
|
31
|
-
|
|
32
|
-
create_sliding_window_causal_mask,
|
|
33
|
-
)
|
|
32
|
+
from transformers.masking_utils import (create_causal_mask,
|
|
33
|
+
create_sliding_window_causal_mask)
|
|
34
34
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
|
35
35
|
from transformers.modeling_layers import GradientCheckpointingLayer
|
|
36
|
-
from transformers.modeling_outputs import (
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
36
|
+
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
|
37
|
+
CausalLMOutputWithPast, ModelOutput)
|
|
38
|
+
from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
|
|
39
|
+
dynamic_rope_update)
|
|
40
|
+
from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
|
|
41
|
+
PreTrainedModel)
|
|
43
42
|
from transformers.processing_utils import Unpack
|
|
44
43
|
from transformers.utils import can_return_tuple, logging
|
|
45
44
|
from transformers.utils.hub import cached_file
|
|
46
45
|
|
|
47
46
|
from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
|
48
|
-
from .configuration_qwen3_tts import (
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
Qwen3TTSTalkerConfig,
|
|
53
|
-
)
|
|
47
|
+
from .configuration_qwen3_tts import (Qwen3TTSConfig,
|
|
48
|
+
Qwen3TTSSpeakerEncoderConfig,
|
|
49
|
+
Qwen3TTSTalkerCodePredictorConfig,
|
|
50
|
+
Qwen3TTSTalkerConfig)
|
|
54
51
|
|
|
55
52
|
logger = logging.get_logger(__name__)
|
|
56
53
|
|
|
57
54
|
|
|
55
|
+
def download_weights_from_hf_specific(
|
|
56
|
+
model_name_or_path: str,
|
|
57
|
+
cache_dir: str | None,
|
|
58
|
+
allow_patterns: list[str],
|
|
59
|
+
revision: str | None = None,
|
|
60
|
+
ignore_patterns: str | list[str] | None = None,
|
|
61
|
+
) -> str:
|
|
62
|
+
"""Download model weights from Hugging Face Hub. Users can specify the
|
|
63
|
+
allow_patterns to download only the necessary weights.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
model_name_or_path (str): The model name or path.
|
|
67
|
+
cache_dir (Optional[str]): The cache directory to store the model
|
|
68
|
+
weights. If None, will use HF defaults.
|
|
69
|
+
allow_patterns (list[str]): The allowed patterns for the
|
|
70
|
+
weight files. Files matched by any of the patterns will be
|
|
71
|
+
downloaded.
|
|
72
|
+
revision (Optional[str]): The revision of the model.
|
|
73
|
+
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
|
74
|
+
filter out the weight files. Files matched by any of the patterns
|
|
75
|
+
will be ignored.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
str: The path to the downloaded model weights.
|
|
79
|
+
"""
|
|
80
|
+
assert len(allow_patterns) > 0
|
|
81
|
+
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
|
|
82
|
+
|
|
83
|
+
for allow_pattern in allow_patterns:
|
|
84
|
+
hf_folder = snapshot_download(
|
|
85
|
+
model_name_or_path,
|
|
86
|
+
allow_patterns=allow_pattern,
|
|
87
|
+
ignore_patterns=ignore_patterns,
|
|
88
|
+
cache_dir=cache_dir,
|
|
89
|
+
revision=revision,
|
|
90
|
+
local_files_only=local_only,
|
|
91
|
+
)
|
|
92
|
+
return hf_folder
|
|
93
|
+
|
|
94
|
+
|
|
58
95
|
class Res2NetBlock(torch.nn.Module):
|
|
59
96
|
def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
|
|
60
97
|
super().__init__()
|
|
@@ -1846,6 +1883,15 @@ class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin)
|
|
|
1846
1883
|
weights_only=weights_only,
|
|
1847
1884
|
**kwargs,
|
|
1848
1885
|
)
|
|
1886
|
+
if not local_files_only and not os.path.isdir(pretrained_model_name_or_path):
|
|
1887
|
+
download_cache_dir = kwargs.get("cache_dir", cache_dir)
|
|
1888
|
+
download_revision = kwargs.get("revision", revision)
|
|
1889
|
+
download_weights_from_hf_specific(
|
|
1890
|
+
pretrained_model_name_or_path,
|
|
1891
|
+
cache_dir=download_cache_dir,
|
|
1892
|
+
allow_patterns=["speech_tokenizer/*"],
|
|
1893
|
+
revision=download_revision,
|
|
1894
|
+
)
|
|
1849
1895
|
speech_tokenizer_path = cached_file(
|
|
1850
1896
|
pretrained_model_name_or_path,
|
|
1851
1897
|
"speech_tokenizer/config.json",
|