qwen-tts 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,7 +19,9 @@ import os
19
19
  from dataclasses import dataclass
20
20
  from typing import Callable, Optional
21
21
 
22
+ import huggingface_hub
22
23
  import torch
24
+ from huggingface_hub import snapshot_download
23
25
  from librosa.filters import mel as librosa_mel_fn
24
26
  from torch import nn
25
27
  from torch.nn import functional as F
@@ -27,34 +29,69 @@ from transformers.activations import ACT2FN
27
29
  from transformers.cache_utils import Cache, DynamicCache
28
30
  from transformers.generation import GenerationMixin
29
31
  from transformers.integrations import use_kernel_forward_from_hub
30
- from transformers.masking_utils import (
31
- create_causal_mask,
32
- create_sliding_window_causal_mask,
33
- )
32
+ from transformers.masking_utils import (create_causal_mask,
33
+ create_sliding_window_causal_mask)
34
34
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
35
  from transformers.modeling_layers import GradientCheckpointingLayer
36
- from transformers.modeling_outputs import (
37
- BaseModelOutputWithPast,
38
- CausalLMOutputWithPast,
39
- ModelOutput,
40
- )
41
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast, ModelOutput)
38
+ from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
39
+ dynamic_rope_update)
40
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
41
+ PreTrainedModel)
43
42
  from transformers.processing_utils import Unpack
44
43
  from transformers.utils import can_return_tuple, logging
45
44
  from transformers.utils.hub import cached_file
46
45
 
47
46
  from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
48
- from .configuration_qwen3_tts import (
49
- Qwen3TTSConfig,
50
- Qwen3TTSSpeakerEncoderConfig,
51
- Qwen3TTSTalkerCodePredictorConfig,
52
- Qwen3TTSTalkerConfig,
53
- )
47
+ from .configuration_qwen3_tts import (Qwen3TTSConfig,
48
+ Qwen3TTSSpeakerEncoderConfig,
49
+ Qwen3TTSTalkerCodePredictorConfig,
50
+ Qwen3TTSTalkerConfig)
54
51
 
55
52
  logger = logging.get_logger(__name__)
56
53
 
57
54
 
55
def download_weights_from_hf_specific(
    model_name_or_path: str,
    cache_dir: str | None,
    allow_patterns: list[str],
    revision: str | None = None,
    ignore_patterns: str | list[str] | None = None,
) -> str:
    """Download model weights from the Hugging Face Hub.

    Users can specify ``allow_patterns`` to download only the necessary
    weight files instead of the full repository.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (list[str]): The allowed patterns for the
            weight files. Files matched by any of the patterns will be
            downloaded.
        revision (Optional[str]): The revision of the model.
        ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
            filter out the weight files. Files matched by any of the patterns
            will be ignored.

    Returns:
        str: The path to the downloaded model weights (the local snapshot
        folder).

    Raises:
        ValueError: If ``allow_patterns`` is empty.
    """
    # `assert` is stripped under `python -O`; raise explicitly so misuse
    # with an empty pattern list is always caught.
    if not allow_patterns:
        raise ValueError("allow_patterns must contain at least one pattern.")

    # Honor the global HF offline switch so no network request is attempted
    # when the user runs with HF_HUB_OFFLINE=1.
    local_only = huggingface_hub.constants.HF_HUB_OFFLINE

    # snapshot_download accepts a list of allow patterns, so a single call
    # fetches the union of all matches. The previous per-pattern loop issued
    # one Hub request per pattern and returned the same snapshot folder each
    # iteration anyway.
    return snapshot_download(
        model_name_or_path,
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
        cache_dir=cache_dir,
        revision=revision,
        local_files_only=local_only,
    )
93
+
94
+
58
95
  class Res2NetBlock(torch.nn.Module):
59
96
  def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
60
97
  super().__init__()
@@ -1846,6 +1883,15 @@ class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin)
1846
1883
  weights_only=weights_only,
1847
1884
  **kwargs,
1848
1885
  )
1886
+ if not local_files_only and not os.path.isdir(pretrained_model_name_or_path):
1887
+ download_cache_dir = kwargs.get("cache_dir", cache_dir)
1888
+ download_revision = kwargs.get("revision", revision)
1889
+ download_weights_from_hf_specific(
1890
+ pretrained_model_name_or_path,
1891
+ cache_dir=download_cache_dir,
1892
+ allow_patterns=["speech_tokenizer/*"],
1893
+ revision=download_revision,
1894
+ )
1849
1895
  speech_tokenizer_path = cached_file(
1850
1896
  pretrained_model_name_or_path,
1851
1897
  "speech_tokenizer/config.json",