soulxpodcast 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- soulxpodcast-0.1.0/PKG-INFO +20 -0
- soulxpodcast-0.1.0/README.md +0 -0
- soulxpodcast-0.1.0/pyproject.toml +32 -0
- soulxpodcast-0.1.0/setup.cfg +4 -0
- soulxpodcast-0.1.0/src/soulxpodcast/__init__.py +0 -0
- soulxpodcast-0.1.0/src/soulxpodcast/config.py +142 -0
- soulxpodcast-0.1.0/src/soulxpodcast/engine/__init__.py +0 -0
- soulxpodcast-0.1.0/src/soulxpodcast/engine/llm_engine.py +114 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/__init__.py +0 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow.py +197 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow_components/__init__.py +0 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow_components/estimator.py +974 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow_components/upsample_encoder.py +997 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/hifigan.py +249 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/hifigan_components/__init__.py +0 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/hifigan_components/layers.py +433 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/modules/sampler.py +221 -0
- soulxpodcast-0.1.0/src/soulxpodcast/models/soulxpodcast.py +168 -0
- soulxpodcast-0.1.0/src/soulxpodcast/utils/__init__.py +0 -0
- soulxpodcast-0.1.0/src/soulxpodcast/utils/audio.py +123 -0
- soulxpodcast-0.1.0/src/soulxpodcast/utils/commons.py +10 -0
- soulxpodcast-0.1.0/src/soulxpodcast/utils/dataloader.py +198 -0
- soulxpodcast-0.1.0/src/soulxpodcast/utils/infer_utils.py +95 -0
- soulxpodcast-0.1.0/src/soulxpodcast/utils/parser.py +87 -0
- soulxpodcast-0.1.0/src/soulxpodcast/utils/text.py +82 -0
- soulxpodcast-0.1.0/src/soulxpodcast.egg-info/PKG-INFO +20 -0
- soulxpodcast-0.1.0/src/soulxpodcast.egg-info/SOURCES.txt +28 -0
- soulxpodcast-0.1.0/src/soulxpodcast.egg-info/dependency_links.txt +1 -0
- soulxpodcast-0.1.0/src/soulxpodcast.egg-info/requires.txt +14 -0
- soulxpodcast-0.1.0/src/soulxpodcast.egg-info/top_level.txt +1 -0
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: soulxpodcast
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: soulx podcast
|
|
5
|
+
Requires-Python: >=3.12
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: accelerate==1.10.1
|
|
8
|
+
Requires-Dist: diffusers==0.37.1
|
|
9
|
+
Requires-Dist: einops==0.8.2
|
|
10
|
+
Requires-Dist: librosa==0.11.0
|
|
11
|
+
Requires-Dist: numpy==2.4.6
|
|
12
|
+
Requires-Dist: onnxruntime==1.26.0
|
|
13
|
+
Requires-Dist: onnxruntime-gpu==1.26.0
|
|
14
|
+
Requires-Dist: s3tokenizer==0.3.0
|
|
15
|
+
Requires-Dist: scipy==1.17.1
|
|
16
|
+
Requires-Dist: sympy==1.13.1
|
|
17
|
+
Requires-Dist: torch==2.5.1+cu121
|
|
18
|
+
Requires-Dist: torchaudio==2.5.1+cu121
|
|
19
|
+
Requires-Dist: torchvision==0.20.1+cu121
|
|
20
|
+
Requires-Dist: transformers==4.57.1
|
|
File without changes
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "soulxpodcast"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "soulx podcast"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"accelerate==1.10.1",
|
|
9
|
+
"diffusers==0.37.1",
|
|
10
|
+
"einops==0.8.2",
|
|
11
|
+
"librosa==0.11.0",
|
|
12
|
+
"numpy==2.4.6",
|
|
13
|
+
"onnxruntime==1.26.0",
|
|
14
|
+
"onnxruntime-gpu==1.26.0",
|
|
15
|
+
"s3tokenizer==0.3.0",
|
|
16
|
+
"scipy==1.17.1",
|
|
17
|
+
"sympy==1.13.1",
|
|
18
|
+
"torch==2.5.1+cu121",
|
|
19
|
+
"torchaudio==2.5.1+cu121",
|
|
20
|
+
"torchvision==0.20.1+cu121",
|
|
21
|
+
"transformers==4.57.1",
|
|
22
|
+
]
|
|
23
|
+
|
|
24
|
+
[[tool.uv.index]]
|
|
25
|
+
name = "pytorch-cu121"
|
|
26
|
+
url = "https://download.pytorch.org/whl/cu121"
|
|
27
|
+
explicit = true
|
|
28
|
+
|
|
29
|
+
[tool.uv.sources]
|
|
30
|
+
torch = { index = "pytorch-cu121" }
|
|
31
|
+
torchvision = { index = "pytorch-cu121" }
|
|
32
|
+
torchaudio = { index = "pytorch-cu121" }
|
|
File without changes
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import dataclass, field, fields, is_dataclass, asdict
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from transformers import AutoConfig
|
|
9
|
+
from transformers import PretrainedConfig
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
|
|
13
|
+
class SoulXPodcastLLMConfig:
|
|
14
|
+
architectures: list[str] = field(default_factory=lambda: ["Qwen3ForCausalLM"])
|
|
15
|
+
attention_dropout: float = 0.0
|
|
16
|
+
bos_token_id: int = 151643
|
|
17
|
+
eos_token_id: int = 151675 # speech eos
|
|
18
|
+
hidden_act: str = "silu"
|
|
19
|
+
hidden_size: int = 2048
|
|
20
|
+
initializer_range: float = 0.02
|
|
21
|
+
intermediate_size: int = 6144
|
|
22
|
+
max_position_embeddings: int = 40960
|
|
23
|
+
max_window_layers: int = 28
|
|
24
|
+
model_type: str = "qwen3"
|
|
25
|
+
num_attention_heads: int = 16
|
|
26
|
+
num_hidden_layers: int = 28
|
|
27
|
+
num_key_value_heads: int = 8
|
|
28
|
+
head_dim: int = 128
|
|
29
|
+
rms_norm_eps: float = 1e-06
|
|
30
|
+
rope_scaling: dict | None = None
|
|
31
|
+
rope_theta: float = 1000000.0
|
|
32
|
+
sliding_window: int = 32768
|
|
33
|
+
tie_word_embeddings: bool = True
|
|
34
|
+
torch_dtype: str = "bfloat16"
|
|
35
|
+
transformers_version: str = "4.52.3"
|
|
36
|
+
use_cache: bool = True
|
|
37
|
+
use_sliding_window: bool = False
|
|
38
|
+
vocab_size: int = 159488 # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
|
|
39
|
+
lm_head_bias: bool = False
|
|
40
|
+
qkv_bias: bool = False
|
|
41
|
+
fp16_flow: bool = False
|
|
42
|
+
speech_token_offset: int = 152927
|
|
43
|
+
|
|
44
|
+
@classmethod
|
|
45
|
+
def from_initial_and_json(
|
|
46
|
+
cls,
|
|
47
|
+
initial_values: Dict[str, Any] = None,
|
|
48
|
+
json_file: Optional[str] = None
|
|
49
|
+
):
|
|
50
|
+
"""
|
|
51
|
+
Create an instance from initial values and JSON data.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
initial_values: Dictionary of initial values (highest priority)
|
|
55
|
+
json_file: Path to JSON file
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
SoulXPodcastLLMConfig instance
|
|
59
|
+
"""
|
|
60
|
+
# Merge all data sources
|
|
61
|
+
merged_data = {}
|
|
62
|
+
|
|
63
|
+
# 1. Load from JSON file first (lowest priority)
|
|
64
|
+
if json_file and os.path.exists(json_file):
|
|
65
|
+
file_data = cls._load_json_file(json_file)
|
|
66
|
+
merged_data.update(file_data)
|
|
67
|
+
|
|
68
|
+
# 2. Overwrite with initial values (highest priority)
|
|
69
|
+
if initial_values:
|
|
70
|
+
merged_data.update(initial_values)
|
|
71
|
+
|
|
72
|
+
# Filter dataclass fields
|
|
73
|
+
valid_fields = {f.name for f in fields(cls)}
|
|
74
|
+
init_data = {k: v for k, v in merged_data.items() if k in valid_fields}
|
|
75
|
+
|
|
76
|
+
return cls(**init_data)
|
|
77
|
+
|
|
78
|
+
@staticmethod
|
|
79
|
+
def _load_json_file(file_path: str) -> Dict[str, Any]:
|
|
80
|
+
"""Load data from a JSON file"""
|
|
81
|
+
path = Path(file_path)
|
|
82
|
+
if not path.exists():
|
|
83
|
+
return {}
|
|
84
|
+
with open(path, 'r', encoding='utf-8') as f:
|
|
85
|
+
return json.load(f)
|
|
86
|
+
|
|
87
|
+
class AutoPretrainedConfig(PretrainedConfig):
    """``PretrainedConfig`` subclass that can be populated from a dataclass instance."""

    model_type = "qwen3"

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``PretrainedConfig`` after dropping non-config keys.

        Keys starting with an underscore and a key literally named ``self``
        are filtered out before the parent constructor is called.
        """
        # Filter out non-configuration parameters
        config_kwargs = {k: v for k, v in kwargs.items()
                         if not k.startswith('_') and k != 'self'}
        super().__init__(**config_kwargs)

    @classmethod
    def from_dataclass(cls, dataclass_config):
        """Create a configuration from the fields of a dataclass instance.

        Args:
            dataclass_config: An *instance* of any dataclass.

        Returns:
            A new configuration whose keyword arguments are
            ``asdict(dataclass_config)``.

        Raises:
            ValueError: If ``dataclass_config`` is not a dataclass instance.
        """
        # is_dataclass() is also True for the dataclass type itself, which
        # asdict() would reject with a TypeError; refuse it here so that every
        # misuse surfaces as the same ValueError.
        if isinstance(dataclass_config, type) or not is_dataclass(dataclass_config):
            raise ValueError("Input must be a dataclass instance")

        return cls(**asdict(dataclass_config))
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class SamplingParams:
    """Decoding options handed to the LLM engines' ``generate`` methods."""

    # Standard sampling controls.
    temperature: float = 0.6
    repetition_penalty: float = 1.25
    top_k: int = 100
    top_p: float = 0.9
    # Lower / upper bound on the number of newly generated tokens.
    min_tokens: int = 8
    max_tokens: int = 3000
    # Fresh list per instance; the default entry equals the speech eos id
    # declared in SoulXPodcastLLMConfig.eos_token_id.
    stop_token_ids: list[int] = field(default_factory=lambda: [151675])
    # RasSampler parameters
    use_ras: bool = True
    win_size: int = 25
    tau_r: float = 0.2
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@dataclass
class Config:
    """Runtime settings shared by the LLM engines.

    ``model`` must be an existing local directory. After construction
    ``max_model_len`` is capped at ``hf_config.max_position_embeddings``
    (8192 is assumed when ``hf_config`` lacks that attribute).

    Raises:
        AssertionError: If ``model`` is not an existing directory.
    """

    model: str
    max_model_len: int = 8192  # 15s prompt + 30s generated audio for 25hz audio tokenizer
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    hf_config: SoulXPodcastLLMConfig | AutoConfig = field(default_factory=SoulXPodcastLLMConfig)
    eos: int = -1
    llm_engine: str = "hf"  # support hf, nano-vllm
    max_turn_size: int = 10
    turn_tokens_threshold: int = 6192

    prompt_context: int = 2  # default to 2 for two-speaker podcast;
    history_context: int = 2
    history_text_context: int = 2

    def __post_init__(self):
        """Validate ``model`` and clamp ``max_model_len`` to the model's limit."""
        # An ``assert`` statement is removed under ``python -O``; raise
        # explicitly so the check always runs. AssertionError is kept so
        # existing callers see the same exception type as before.
        if not os.path.isdir(self.model):
            raise AssertionError(
                f"model must be an existing directory, got {self.model!r}"
            )

        max_pos = getattr(self.hf_config, "max_position_embeddings", 8192)
        self.max_model_len = min(self.max_model_len, max_pos)
|
|
File without changes
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import types
|
|
3
|
+
import atexit
|
|
4
|
+
from time import perf_counter
|
|
5
|
+
from functools import partial
|
|
6
|
+
from dataclasses import fields, asdict
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.multiprocessing as mp
|
|
10
|
+
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
|
|
11
|
+
from transformers import EosTokenCriteria, RepetitionPenaltyLogitsProcessor
|
|
12
|
+
try:
|
|
13
|
+
from vllm import LLM
|
|
14
|
+
from vllm import SamplingParams as VllmSamplingParams
|
|
15
|
+
from vllm.inputs import TokensPrompt as TokensPrompt
|
|
16
|
+
SUPPORT_VLLM = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
SUPPORT_VLLM = False
|
|
19
|
+
|
|
20
|
+
from soulxpodcast.config import Config, SamplingParams
|
|
21
|
+
from soulxpodcast.models.modules.sampler import _ras_sample_hf_engine
|
|
22
|
+
|
|
23
|
+
class HFLLMEngine:
    """LLM backend that samples tokens with a Hugging Face causal language model."""

    def __init__(self, model, **kwargs):
        """Load tokenizer and model from ``model``.

        Args:
            model: Path handed to ``Config`` and to the ``from_pretrained`` loaders.
            **kwargs: Optional ``Config`` fields; unknown keys are ignored.
        """
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
        config.eos = config.hf_config.eos_token_id  # speech eos token;
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16, device_map=self.device)
        self.config = config
        self.pad_token_id = self.tokenizer.pad_token_id

    def generate(
        self,
        prompt: list[int],
        sampling_param: SamplingParams,
        past_key_values=None,
    ) -> dict:
        """Sample a continuation for one prompt.

        Args:
            prompt: Prompt token ids (they are turned into an int64 tensor,
                so this is a list of ints, not of strings).
            sampling_param: Decoding options.
            past_key_values: Optional cache forwarded to the model's ``generate``.

        Returns:
            ``{"text": decoded continuation, "token_ids": generated ids}``;
            the prompt itself is stripped from both.
        """
        # The prompt length is needed twice: to exempt the prompt from the
        # repetition penalty and to slice it off the generated sequence.
        prompt_len = len(prompt)

        stopping_criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=self.config.hf_config.eos_token_id)])
        if sampling_param.use_ras:
            sample_hf_engine_handler = partial(_ras_sample_hf_engine,
                                               use_ras=sampling_param.use_ras,
                                               win_size=sampling_param.win_size, tau_r=sampling_param.tau_r)
        else:
            sample_hf_engine_handler = None
        rep_pen_processor = RepetitionPenaltyLogitsProcessor(
            penalty=sampling_param.repetition_penalty,
            prompt_ignore_length=prompt_len
        )  # exclude the input prompt, consistent with vLLM implementation;
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=torch.tensor([prompt], dtype=torch.int64).to(self.device),
                do_sample=True,
                top_k=sampling_param.top_k,
                top_p=sampling_param.top_p,
                min_new_tokens=sampling_param.min_tokens,
                max_new_tokens=sampling_param.max_tokens,
                temperature=sampling_param.temperature,
                stopping_criteria=stopping_criteria,
                past_key_values=past_key_values,
                custom_generate=sample_hf_engine_handler,
                use_cache=True,
                logits_processor=[rep_pen_processor]
            )
            # Batch size is 1: keep only the newly generated part of row 0.
            generated_ids = generated_ids[:, prompt_len:].cpu().numpy().tolist()[0]
        output = {
            "text": self.tokenizer.decode(generated_ids),
            "token_ids": generated_ids,
        }
        return output
|
|
77
|
+
|
|
78
|
+
class VLLMEngine:
    """LLM backend that samples tokens through vLLM."""

    def __init__(self, model, **kwargs):
        """Create the vLLM engine and tokenizer for ``model``.

        Args:
            model: Path handed to ``Config``, the tokenizer and ``vllm.LLM``.
            **kwargs: Optional ``Config`` fields; unknown keys are ignored.

        Raises:
            ImportError: If vLLM could not be imported.
        """
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)

        # Fail before loading the tokenizer when vLLM is unavailable.
        if not SUPPORT_VLLM:
            raise ImportError("Not Support VLLM now!!!")

        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = config.hf_config.eos_token_id  # speech eos token;
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Must be set before the engine is constructed.
        os.environ["VLLM_USE_V1"] = "0"
        self.model = LLM(model=model, enforce_eager=True, dtype="bfloat16", max_model_len=8192, enable_prefix_caching=True,)
        self.config = config
        self.pad_token_id = self.tokenizer.pad_token_id

    def generate(
        self,
        prompt: list[int],
        sampling_param: SamplingParams,
        past_key_values=None,
    ) -> dict:
        """Sample a continuation for one prompt.

        Args:
            prompt: Prompt token ids (passed as ``prompt_token_ids``).
            sampling_param: Decoding options. The object is not modified;
                its ``stop_token_ids`` are replaced by the configured eos id
                for this call only.
            past_key_values: Unused; present so the signature matches
                ``HFLLMEngine.generate``.

        Returns:
            ``{"text": decoded continuation, "token_ids": generated ids}``.
        """
        # asdict() returns a copy, so overriding the stop ids here leaves the
        # caller's SamplingParams untouched.
        vllm_kwargs = asdict(sampling_param)
        vllm_kwargs["stop_token_ids"] = [self.config.hf_config.eos_token_id]
        with torch.no_grad():
            generated_ids = self.model.generate(
                TokensPrompt(prompt_token_ids=prompt),
                VllmSamplingParams(**vllm_kwargs),
                use_tqdm=False,
            )[0].outputs[0].token_ids
        output = {
            "text": self.tokenizer.decode(generated_ids),
            "token_ids": list(generated_ids),
        }
        return output
|
|
File without changes
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
from soulxpodcast.models.modules.flow_components.estimator import \
|
|
8
|
+
CausalConditionalDecoder
|
|
9
|
+
from soulxpodcast.models.modules.flow_components.upsample_encoder import (
|
|
10
|
+
UpsampleConformerEncoder, make_pad_mask)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class CfmParams:
    """Settings read by ``CausalConditionalCFM`` at construction time."""

    sigma_min: float = 1e-6
    # Only stored by CausalConditionalCFM; its forward always uses the Euler solver.
    solver: str = "euler"
    # 'cosine' warps the uniform time grid; any other value keeps it linear.
    t_scheduler: str = "cosine"
    training_cfg_rate: float = 0.2
    # Classifier-free guidance weight applied during sampling.
    inference_cfg_rate: float = 0.7
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CausalConditionalCFM(torch.nn.Module):
    """Flow-matching sampler: integrates an estimator with fixed-step Euler
    updates and classifier-free guidance."""

    def __init__(self, in_channels=320, cfm_params=None, n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
        """
        Args:
            in_channels: Stored as ``n_feats``.
            cfm_params: Object exposing ``solver``, ``t_scheduler``,
                ``training_cfg_rate``, ``inference_cfg_rate`` and optionally
                ``sigma_min``. ``None`` (default) creates a fresh ``CfmParams()``.
            n_spks: Stored as ``n_spks``.
            spk_emb_dim: Stored as ``spk_emb_dim``.
            estimator: Network predicting the flow; a default
                ``CausalConditionalDecoder()`` is built when ``None``.
        """
        super().__init__()
        # Build the default per call instead of in the signature, where a
        # single mutable instance would be shared by every module.
        if cfm_params is None:
            cfm_params = CfmParams()
        self.n_feats = in_channels
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        self.sigma_min = getattr(cfm_params, "sigma_min", 1e-4)
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        # Just change the architecture of the estimator here
        self.estimator = CausalConditionalDecoder() if estimator is None else estimator

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
        """Sample a mel-spectrogram starting from Gaussian noise.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of Euler steps (must be >= 1)
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor): speaker embedding
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor): conditioning features, same layout as ``mu``
            streaming (bool): forwarded unchanged to the estimator

        NOTE: ``spks`` and ``cond`` default to None in the signature, but
        ``solve_euler`` reads their sizes, so both must be tensors.

        Returns:
            tuple: ``(sample, None)`` where sample is the generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        # randn_like already matches mu's device and dtype.
        z = torch.randn_like(mu) * temperature
        # fix prompt and overlap part mu and z
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None

    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
        """
        Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor): speaker embedding
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor): conditioning features (3-D, batch first)
            streaming (bool): forwarded unchanged to the estimator

        Returns:
            torch.Tensor: state after the last step, cast to float32.
        """
        batch_size = x.size(0)
        t, dt = t_span[0], t_span[1] - t_span[0]

        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        # Create tensors with double batch size for CFG (conditional + unconditional)
        x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)

        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            # Copy conditional and unconditional input
            x_in[:batch_size] = x
            x_in[batch_size:] = x
            mask_in[:batch_size] = mask
            mask_in[batch_size:] = mask
            mu_in[:batch_size] = mu
            # Unconditional part remains 0
            t_in.fill_(t)
            spks_in[:batch_size] = spks
            cond_in[:batch_size] = cond

            dphi_dt = self.estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in,
                streaming
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        # Only the final state is returned, so intermediate states are not
        # collected; this lets them be freed as soon as they are superseded.
        return x.float()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
    """Token-to-mel model: embeds tokens, encodes them, then samples a
    mel-spectrogram with a flow-matching decoder conditioned on a speaker
    vector and prompt features."""

    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 80,
        spk_embed_dim: int = 192,
        output_type: str = "mel",
        vocab_size: int = 6561,
        input_frame_rate: int = 25,
        token_mel_ratio: int = 2,
        pre_lookahead_len: int = 3,
        encoder: torch.nn.Module = None,
        decoder: torch.nn.Module = None,
    ):
        """Store the hyper-parameters and build the sub-modules.

        ``encoder`` / ``decoder`` default to ``UpsampleConformerEncoder()`` and
        ``CausalConditionalCFM()`` when ``None``. The encoder must provide
        ``output_size()``.
        """
        super().__init__()
        # Plain hyper-parameters.
        self.input_size = input_size
        self.output_size = output_size
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate

        # Sub-modules; creation order is kept because it fixes the
        # registration order of the children.
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        if encoder is None:
            encoder = UpsampleConformerEncoder()
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        if decoder is None:
            decoder = CausalConditionalCFM()
        self.decoder = decoder

        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    @torch.inference_mode()
    def forward(self,
                token,
                token_len,
                prompt_feat,
                prompt_feat_len,
                embedding,
                streaming,
                finalize):
        """Generate mel features for ``token``.

        Args:
            token: token ids, batch first; negative ids are clamped to 0.
            token_len: valid length of each token row.
            prompt_feat: prompt features copied into the leading frames of
                the condition tensor (indexed as ``[batch, frame]``).
            prompt_feat_len: number of prompt frames per batch item.
            embedding: speaker vector, normalised along dim 1 and projected.
            streaming: forwarded to the encoder and the decoder.
            finalize: when exactly ``True`` the whole sequence is encoded;
                otherwise the last ``pre_lookahead_len`` tokens are passed to
                the encoder as ``context`` only.

        Returns:
            tuple: ``(feat, h_lengths)`` — float32 features ``[B, num_mels, T]``
            and the per-item lengths derived from the encoder's second output.
        """
        # Speaker vector: normalise, then project.
        speaker = self.spk_embed_affine_layer(F.normalize(embedding, dim=1))

        # Embed the tokens and zero out the padded positions.
        valid_tokens = ~make_pad_mask(token_len, max_len=token.shape[1])
        token_mask = valid_tokens.unsqueeze(-1).to(speaker)
        embedded = self.input_embedding(torch.clamp(token, min=0)) * token_mask

        # Encode; in the non-final case the tail only serves as look-ahead context.
        if finalize is True:
            h, h_lengths = self.encoder(embedded, token_len, streaming=streaming)
        else:
            lookahead = self.pre_lookahead_len
            context = embedded[:, -lookahead:]
            embedded = embedded[:, :-lookahead]
            h, h_lengths = self.encoder(embedded, token_len, context=context, streaming=streaming)
        h = self.encoder_proj(h)

        # Conditions: prompt features in the first frames, zeros elsewhere.
        conds = torch.zeros_like(h, device=embedded.device)
        for row, n_frames in enumerate(prompt_feat_len):
            conds[row, :n_frames] = prompt_feat[row, :n_frames]
        conds = conds.transpose(1, 2)

        h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
        frame_mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=frame_mask.unsqueeze(1),
            spks=speaker,
            cond=conds,
            n_timesteps=15,
            streaming=streaming
        )  # [B, num_mels, T]
        return feat.float(), h_lengths
|
|
File without changes
|