soulxpodcast 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. soulxpodcast-0.1.0/PKG-INFO +20 -0
  2. soulxpodcast-0.1.0/README.md +0 -0
  3. soulxpodcast-0.1.0/pyproject.toml +32 -0
  4. soulxpodcast-0.1.0/setup.cfg +4 -0
  5. soulxpodcast-0.1.0/src/soulxpodcast/__init__.py +0 -0
  6. soulxpodcast-0.1.0/src/soulxpodcast/config.py +142 -0
  7. soulxpodcast-0.1.0/src/soulxpodcast/engine/__init__.py +0 -0
  8. soulxpodcast-0.1.0/src/soulxpodcast/engine/llm_engine.py +114 -0
  9. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/__init__.py +0 -0
  10. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow.py +197 -0
  11. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow_components/__init__.py +0 -0
  12. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow_components/estimator.py +974 -0
  13. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/flow_components/upsample_encoder.py +997 -0
  14. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/hifigan.py +249 -0
  15. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/hifigan_components/__init__.py +0 -0
  16. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/hifigan_components/layers.py +433 -0
  17. soulxpodcast-0.1.0/src/soulxpodcast/models/modules/sampler.py +221 -0
  18. soulxpodcast-0.1.0/src/soulxpodcast/models/soulxpodcast.py +168 -0
  19. soulxpodcast-0.1.0/src/soulxpodcast/utils/__init__.py +0 -0
  20. soulxpodcast-0.1.0/src/soulxpodcast/utils/audio.py +123 -0
  21. soulxpodcast-0.1.0/src/soulxpodcast/utils/commons.py +10 -0
  22. soulxpodcast-0.1.0/src/soulxpodcast/utils/dataloader.py +198 -0
  23. soulxpodcast-0.1.0/src/soulxpodcast/utils/infer_utils.py +95 -0
  24. soulxpodcast-0.1.0/src/soulxpodcast/utils/parser.py +87 -0
  25. soulxpodcast-0.1.0/src/soulxpodcast/utils/text.py +82 -0
  26. soulxpodcast-0.1.0/src/soulxpodcast.egg-info/PKG-INFO +20 -0
  27. soulxpodcast-0.1.0/src/soulxpodcast.egg-info/SOURCES.txt +28 -0
  28. soulxpodcast-0.1.0/src/soulxpodcast.egg-info/dependency_links.txt +1 -0
  29. soulxpodcast-0.1.0/src/soulxpodcast.egg-info/requires.txt +14 -0
  30. soulxpodcast-0.1.0/src/soulxpodcast.egg-info/top_level.txt +1 -0
@@ -0,0 +1,20 @@
1
+ Metadata-Version: 2.4
2
+ Name: soulxpodcast
3
+ Version: 0.1.0
4
+ Summary: soulx podcast
5
+ Requires-Python: >=3.12
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: accelerate==1.10.1
8
+ Requires-Dist: diffusers==0.37.1
9
+ Requires-Dist: einops==0.8.2
10
+ Requires-Dist: librosa==0.11.0
11
+ Requires-Dist: numpy==2.4.6
12
+ Requires-Dist: onnxruntime==1.26.0
13
+ Requires-Dist: onnxruntime-gpu==1.26.0
14
+ Requires-Dist: s3tokenizer==0.3.0
15
+ Requires-Dist: scipy==1.17.1
16
+ Requires-Dist: sympy==1.13.1
17
+ Requires-Dist: torch==2.5.1+cu121
18
+ Requires-Dist: torchaudio==2.5.1+cu121
19
+ Requires-Dist: torchvision==0.20.1+cu121
20
+ Requires-Dist: transformers==4.57.1
File without changes
@@ -0,0 +1,32 @@
1
+ [project]
2
+ name = "soulxpodcast"
3
+ version = "0.1.0"
4
+ description = "soulx podcast"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "accelerate==1.10.1",
9
+ "diffusers==0.37.1",
10
+ "einops==0.8.2",
11
+ "librosa==0.11.0",
12
+ "numpy==2.4.6",
13
+ "onnxruntime==1.26.0",
14
+ "onnxruntime-gpu==1.26.0",
15
+ "s3tokenizer==0.3.0",
16
+ "scipy==1.17.1",
17
+ "sympy==1.13.1",
18
+ "torch==2.5.1+cu121",
19
+ "torchaudio==2.5.1+cu121",
20
+ "torchvision==0.20.1+cu121",
21
+ "transformers==4.57.1",
22
+ ]
23
+
24
+ [[tool.uv.index]]
25
+ name = "pytorch-cu121"
26
+ url = "https://download.pytorch.org/whl/cu121"
27
+ explicit = true
28
+
29
+ [tool.uv.sources]
30
+ torch = { index = "pytorch-cu121" }
31
+ torchvision = { index = "pytorch-cu121" }
32
+ torchaudio = { index = "pytorch-cu121" }
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
File without changes
@@ -0,0 +1,142 @@
1
+ import os
2
+ from dataclasses import dataclass, field, fields, is_dataclass, asdict
3
+ from typing import Any, Dict, List, Optional
4
+ from pathlib import Path
5
+ import json
6
+
7
+ import torch
8
+ from transformers import AutoConfig
9
+ from transformers import PretrainedConfig
10
+
11
+
12
+ @dataclass
13
+ class SoulXPodcastLLMConfig:
14
+ architectures: list[str] = field(default_factory=lambda: ["Qwen3ForCausalLM"])
15
+ attention_dropout: float = 0.0
16
+ bos_token_id: int = 151643
17
+ eos_token_id: int = 151675 # speech eos
18
+ hidden_act: str = "silu"
19
+ hidden_size: int = 2048
20
+ initializer_range: float = 0.02
21
+ intermediate_size: int = 6144
22
+ max_position_embeddings: int = 40960
23
+ max_window_layers: int = 28
24
+ model_type: str = "qwen3"
25
+ num_attention_heads: int = 16
26
+ num_hidden_layers: int = 28
27
+ num_key_value_heads: int = 8
28
+ head_dim: int = 128
29
+ rms_norm_eps: float = 1e-06
30
+ rope_scaling: dict | None = None
31
+ rope_theta: float = 1000000.0
32
+ sliding_window: int = 32768
33
+ tie_word_embeddings: bool = True
34
+ torch_dtype: str = "bfloat16"
35
+ transformers_version: str = "4.52.3"
36
+ use_cache: bool = True
37
+ use_sliding_window: bool = False
38
+ vocab_size: int = 159488 # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
39
+ lm_head_bias: bool = False
40
+ qkv_bias: bool = False
41
+ fp16_flow: bool = False
42
+ speech_token_offset: int = 152927
43
+
44
+ @classmethod
45
+ def from_initial_and_json(
46
+ cls,
47
+ initial_values: Dict[str, Any] = None,
48
+ json_file: Optional[str] = None
49
+ ):
50
+ """
51
+ Create an instance from initial values and JSON data.
52
+
53
+ Args:
54
+ initial_values: Dictionary of initial values (highest priority)
55
+ json_file: Path to JSON file
56
+
57
+ Returns:
58
+ SoulXPodcastLLMConfig instance
59
+ """
60
+ # Merge all data sources
61
+ merged_data = {}
62
+
63
+ # 1. Load from JSON file first (lowest priority)
64
+ if json_file and os.path.exists(json_file):
65
+ file_data = cls._load_json_file(json_file)
66
+ merged_data.update(file_data)
67
+
68
+ # 2. Overwrite with initial values (highest priority)
69
+ if initial_values:
70
+ merged_data.update(initial_values)
71
+
72
+ # Filter dataclass fields
73
+ valid_fields = {f.name for f in fields(cls)}
74
+ init_data = {k: v for k, v in merged_data.items() if k in valid_fields}
75
+
76
+ return cls(**init_data)
77
+
78
+ @staticmethod
79
+ def _load_json_file(file_path: str) -> Dict[str, Any]:
80
+ """Load data from a JSON file"""
81
+ path = Path(file_path)
82
+ if not path.exists():
83
+ return {}
84
+ with open(path, 'r', encoding='utf-8') as f:
85
+ return json.load(f)
86
+
87
class AutoPretrainedConfig(PretrainedConfig):
    """``PretrainedConfig`` subclass (``model_type = "qwen3"``) that can be built from a dataclass."""

    model_type = "qwen3"

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``PretrainedConfig``, minus private keys and ``self``."""
        # Filter out non-configuration parameters
        config_kwargs = {
            key: value
            for key, value in kwargs.items()
            if not key.startswith('_') and key != 'self'
        }
        super().__init__(**config_kwargs)

    @classmethod
    def from_dataclass(cls, dataclass_config):
        """Automatically create configuration from any dataclass instance.

        Args:
            dataclass_config: An *instance* of a dataclass; its fields
                (converted with ``dataclasses.asdict``) become the keyword
                arguments of the new configuration.

        Raises:
            ValueError: If ``dataclass_config`` is not a dataclass instance.
        """
        # is_dataclass() is also True for dataclass *classes*; those would
        # slip past the check and fail inside asdict() with a TypeError.
        if isinstance(dataclass_config, type) or not is_dataclass(dataclass_config):
            raise ValueError("Input must be a dataclass instance")

        return cls(**asdict(dataclass_config))
104
+
105
+
106
@dataclass
class SamplingParams:
    """Knobs that control token sampling during generation.

    Field order is part of the positional ``__init__`` signature and must not
    be changed.
    """

    # --- generic sampling controls ---
    temperature: float = 0.6
    repetition_penalty: float = 1.25
    top_k: int = 100
    top_p: float = 0.9

    # --- bounds on the number of generated tokens ---
    min_tokens: int = 8
    max_tokens: int = 3000

    # Token ids that end generation. A fresh list is created per instance.
    stop_token_ids: list[int] = field(default_factory=lambda: [151675])

    # --- RasSampler parameters ---
    use_ras: bool = True
    win_size: int = 25
    tau_r: float = 0.2
119
+
120
+
121
@dataclass
class Config:
    """Engine-level settings for running the SoulX-Podcast language model.

    ``model`` must be the path of an existing directory; this is checked in
    ``__post_init__``, which also clamps ``max_model_len`` to the
    ``max_position_embeddings`` of ``hf_config``.
    """

    model: str
    max_model_len: int = 8192  # 15s prompt + 30s generated audio for 25hz audio tokenizer
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    enforce_eager: bool = False
    hf_config: SoulXPodcastLLMConfig | AutoConfig = field(default_factory=SoulXPodcastLLMConfig)
    eos: int = -1  # placeholder value
    llm_engine: str = "hf"  # support hf, nano-vllm
    max_turn_size: int = 10
    turn_tokens_threshold: int = 6192

    prompt_context: int = 2  # default to 2 for two-speaker podcast;
    history_context: int = 2
    history_text_context: int = 2

    def __post_init__(self):
        """Validate ``model`` and clamp ``max_model_len``.

        Raises:
            NotADirectoryError: If ``model`` is not an existing directory.
        """
        # An explicit raise instead of ``assert``: assertions are removed
        # when Python runs with -O, which would silently skip the check.
        if not os.path.isdir(self.model):
            raise NotADirectoryError(
                f"`model` must be an existing directory, got {self.model!r}"
            )

        # Fall back to 8192 when hf_config carries no position limit.
        max_pos = getattr(self.hf_config, "max_position_embeddings", 8192)
        self.max_model_len = min(self.max_model_len, max_pos)
File without changes
@@ -0,0 +1,114 @@
1
+ import os
2
+ import types
3
+ import atexit
4
+ from time import perf_counter
5
+ from functools import partial
6
+ from dataclasses import fields, asdict
7
+
8
+ import torch
9
+ import torch.multiprocessing as mp
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
11
+ from transformers import EosTokenCriteria, RepetitionPenaltyLogitsProcessor
12
+ try:
13
+ from vllm import LLM
14
+ from vllm import SamplingParams as VllmSamplingParams
15
+ from vllm.inputs import TokensPrompt as TokensPrompt
16
+ SUPPORT_VLLM = True
17
+ except ImportError:
18
+ SUPPORT_VLLM = False
19
+
20
+ from soulxpodcast.config import Config, SamplingParams
21
+ from soulxpodcast.models.modules.sampler import _ras_sample_hf_engine
22
+
23
class HFLLMEngine:
    """Generation backend that runs the model through Hugging Face ``transformers``."""

    def __init__(self, model, **kwargs):
        """Load tokenizer and model from ``model``.

        Args:
            model: Passed to ``Config`` and to both ``from_pretrained`` calls.
            **kwargs: Only entries whose key names a ``Config`` field are
                used; everything else is ignored.
        """
        config_fields = {cfg_field.name for cfg_field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
        config.eos = config.hf_config.eos_token_id  # speech eos token;
        # First CUDA device when available, CPU otherwise.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16, device_map=self.device)
        self.config = config
        self.pad_token_id = self.tokenizer.pad_token_id

    def generate(
        self,
        prompt: list[int],
        sampling_param: SamplingParams,
        past_key_values=None,
    ) -> dict:
        """Sample a continuation of ``prompt``.

        Args:
            prompt: Token ids of a single sequence (it is wrapped into a
                batch of one and converted to an ``int64`` tensor).
            sampling_param: Sampling settings; when ``use_ras`` is true the
                RAS sampler is plugged in through ``custom_generate``.
            past_key_values: Forwarded unchanged to ``model.generate``.

        Returns:
            ``{"text": <decoded continuation>, "token_ids": <list of new ids>}``;
            the prompt tokens are stripped from the result.
        """
        input_len = len(prompt)

        stopping_criteria = StoppingCriteriaList(
            [EosTokenCriteria(eos_token_id=self.config.hf_config.eos_token_id)]
        )
        if sampling_param.use_ras:
            sample_hf_engine_handler = partial(
                _ras_sample_hf_engine,
                use_ras=sampling_param.use_ras,
                win_size=sampling_param.win_size,
                tau_r=sampling_param.tau_r,
            )
        else:
            sample_hf_engine_handler = None
        rep_pen_processor = RepetitionPenaltyLogitsProcessor(
            penalty=sampling_param.repetition_penalty,
            prompt_ignore_length=input_len,
        )  # exclude the input prompt, consistent with vLLM implementation;

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=torch.tensor([prompt], dtype=torch.int64).to(self.device),
                do_sample=True,
                top_k=sampling_param.top_k,
                top_p=sampling_param.top_p,
                min_new_tokens=sampling_param.min_tokens,
                max_new_tokens=sampling_param.max_tokens,
                temperature=sampling_param.temperature,
                stopping_criteria=stopping_criteria,
                past_key_values=past_key_values,
                custom_generate=sample_hf_engine_handler,
                use_cache=True,
                logits_processor=[rep_pen_processor],
            )
            # Drop the prompt prefix and keep the single batch entry.
            generated_ids = generated_ids[:, input_len:].cpu().numpy().tolist()[0]

        output = {
            "text": self.tokenizer.decode(generated_ids),
            "token_ids": generated_ids,
        }
        return output
77
+
78
class VLLMEngine:
    """Generation backend that runs the model through vLLM."""

    def __init__(self, model, **kwargs):
        """Load the tokenizer and create the vLLM ``LLM`` instance.

        Args:
            model: Passed to ``Config`` and to ``LLM(model=...)``.
            **kwargs: Only entries whose key names a ``Config`` field are
                used; everything else is ignored.

        Raises:
            ImportError: If the ``vllm`` package could not be imported.
        """
        config_fields = {cfg_field.name for cfg_field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)

        self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
        config.eos = config.hf_config.eos_token_id  # speech eos token;
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Process-wide side effect: sets VLLM_USE_V1 to "0" before LLM() is created.
        os.environ["VLLM_USE_V1"] = "0"
        if SUPPORT_VLLM:
            self.model = LLM(model=model, enforce_eager=True, dtype="bfloat16", max_model_len=8192, enable_prefix_caching=True,)
        else:
            raise ImportError("Not Support VLLM now!!!")
        self.config = config
        self.pad_token_id = self.tokenizer.pad_token_id

    def generate(
        self,
        prompt: list[int],
        sampling_param: SamplingParams,
        past_key_values=None,
    ) -> dict:
        """Sample a continuation of ``prompt`` with vLLM.

        Args:
            prompt: Token ids of a single sequence.
            sampling_param: Sampling settings. Its ``stop_token_ids`` are
                replaced by the configured EOS id for this call only; the
                caller's object is left untouched.
            past_key_values: Accepted for signature parity with
                ``HFLLMEngine.generate``; not used here.

        Returns:
            ``{"text": <decoded continuation>, "token_ids": <list of new ids>}``.
        """
        # Work on a dict copy so the caller's SamplingParams is not mutated.
        vllm_kwargs = asdict(sampling_param)
        vllm_kwargs["stop_token_ids"] = [self.config.hf_config.eos_token_id]
        # NOTE(review): every SamplingParams field is forwarded, including
        # use_ras / win_size / tau_r — confirm the installed vLLM accepts them.
        with torch.no_grad():
            generated_ids = self.model.generate(
                TokensPrompt(prompt_token_ids=prompt),
                VllmSamplingParams(**vllm_kwargs),
                use_tqdm=False,
            )[0].outputs[0].token_ids
        output = {
            "text": self.tokenizer.decode(generated_ids),
            "token_ids": list(generated_ids),
        }
        return output
@@ -0,0 +1,197 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from soulxpodcast.models.modules.flow_components.estimator import \
8
+ CausalConditionalDecoder
9
+ from soulxpodcast.models.modules.flow_components.upsample_encoder import (
10
+ UpsampleConformerEncoder, make_pad_mask)
11
+
12
+
13
@dataclass
class CfmParams:
    """Settings read by ``CausalConditionalCFM`` at construction time."""

    # Used when present on the params object; otherwise the CFM module falls back to 1e-4.
    sigma_min: float = 1e-6
    # Stored as ``solver`` on the CFM module.
    solver: str = "euler"
    # 'cosine' makes the CFM module warp its time grid; other values leave it linear.
    t_scheduler: str = "cosine"
    # Stored as ``training_cfg_rate`` on the CFM module.
    training_cfg_rate: float = 0.2
    # Classifier-free-guidance weight used when sampling.
    inference_cfg_rate: float = 0.7
20
+
21
+
22
class CausalConditionalCFM(torch.nn.Module):
    """Conditional flow-matching sampler using a fixed-step Euler solver with classifier-free guidance."""

    def __init__(self, in_channels=320, cfm_params=None, n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
        """Store sampler settings and the estimator network.

        Args:
            in_channels: Stored as ``n_feats``.
            cfm_params: Object exposing ``solver``, ``t_scheduler``,
                ``training_cfg_rate``, ``inference_cfg_rate`` and optionally
                ``sigma_min``. ``None`` (default) means ``CfmParams()``.
            n_spks: Stored as ``n_spks``.
            spk_emb_dim: Stored as ``spk_emb_dim``.
            estimator: Network called by the solver; defaults to
                ``CausalConditionalDecoder()``.
        """
        super().__init__()
        # Build the default per instance instead of sharing one mutable
        # object created at function-definition time.
        if cfm_params is None:
            cfm_params = CfmParams()
        self.n_feats = in_channels
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        self.sigma_min = getattr(cfm_params, "sigma_min", 1e-4)
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        # Just change the architecture of the estimator here
        self.estimator = CausalConditionalDecoder() if estimator is None else estimator

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
        """Sample from noise by integrating the estimator over ``n_timesteps`` Euler steps.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of Euler steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor): speaker embedding,
                shape: (batch_size, spk_emb_dim). Although the default is
                ``None``, the solver reads ``spks.size(1)``, so a tensor is required.
            cond (torch.Tensor): conditioning features, 3-D. Although the
                default is ``None``, the solver reads ``cond.size(1)`` and
                ``cond.size(2)``, so a tensor is required.
            streaming (bool): forwarded to the estimator.

        Returns:
            tuple: ``(sample, None)`` where ``sample`` is the generated
            mel-spectrogram as float32,
            shape: (batch_size, n_feats, mel_timesteps)
        """
        # randn_like already matches mu's device and dtype.
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        if self.t_scheduler == 'cosine':
            t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None

    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
        """Fixed euler solver for ODEs.

        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor): speaker embedding
                shape: (batch_size, spk_emb_dim)
            cond (torch.Tensor): conditioning features, 3-D
            streaming (bool): forwarded to the estimator.

        Returns:
            torch.Tensor: the state after the last step, cast to float32.
            If ``t_span`` has a single entry no step is taken and ``x`` is
            returned as-is (cast to float32).
        """
        batch_size = x.size(0)
        t = t_span[0]
        dt = t_span[1] - t_span[0] if len(t_span) > 1 else None

        # Do not use concat, it may cause memory format changed and trt infer with wrong results!
        # Create tensors with double batch size for CFG (conditional + unconditional)
        x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
        mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
        mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
        t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
        spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
        cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)

        last_step = len(t_span) - 1
        for step in range(1, len(t_span)):
            # Classifier-Free Guidance inference introduced in VoiceBox
            # Copy conditional and unconditional input. All buffers are
            # refilled on every step, exactly as before.
            x_in[:batch_size] = x
            x_in[batch_size:] = x
            mask_in[:batch_size] = mask
            mask_in[batch_size:] = mask
            mu_in[:batch_size] = mu
            # Unconditional part (second half of mu/spks/cond) remains 0
            t_in.fill_(t)
            spks_in[:batch_size] = spks
            cond_in[:batch_size] = cond

            dphi_dt = self.estimator(
                x_in, mask_in,
                mu_in, t_in,
                spks_in,
                cond_in,
                streaming
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            if step < last_step:
                dt = t_span[step + 1] - t

        # Only the final state is needed, so intermediate states are not kept.
        return x.float()
126
+
127
+
128
class CausalMaskedDiffWithXvec(torch.nn.Module):
    """Token-to-mel model: embeds tokens, encodes them, then samples a mel-spectrogram with the CFM decoder."""

    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 80,
        spk_embed_dim: int = 192,
        output_type: str = "mel",
        vocab_size: int = 6561,
        input_frame_rate: int = 25,
        token_mel_ratio: int = 2,
        pre_lookahead_len: int = 3,
        encoder: torch.nn.Module = None,
        decoder: torch.nn.Module = None,
    ):
        """Build the embedding, projection layers, encoder and decoder.

        Args:
            input_size: Token embedding width.
            output_size: Width produced by the speaker and encoder projections.
            spk_embed_dim: Width of the incoming speaker embedding.
            output_type: Stored only.
            vocab_size: Number of rows in the token embedding table.
            input_frame_rate: Stored only.
            token_mel_ratio: Stored only.
            pre_lookahead_len: Number of trailing tokens split off as
                look-ahead context when ``finalize`` is not ``True``.
            encoder: Defaults to ``UpsampleConformerEncoder()``.
            decoder: Defaults to ``CausalConditionalCFM()``.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        self.input_embedding = nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = UpsampleConformerEncoder() if encoder is None else encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = CausalConditionalCFM() if decoder is None else decoder
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    @torch.inference_mode()
    def forward(self,
                token,
                token_len,
                prompt_feat,
                prompt_feat_len,
                embedding,
                streaming,
                finalize,
                n_timesteps: int = 15):
        """Generate mel features for ``token``.

        Args:
            token: Token ids, indexed as (batch, time); negative ids are
                clamped to 0 before the embedding lookup.
            token_len: Per-sequence token counts used to build the padding mask.
            prompt_feat: Prompt features copied into the start of the
                conditioning tensor.
            prompt_feat_len: Per-sequence number of prompt frames to copy.
            embedding: Speaker embedding; L2-normalised along dim 1 and
                projected to ``output_size``.
            streaming: Forwarded to encoder and decoder.
            finalize: When exactly ``True`` the whole token sequence is
                encoded; otherwise the last ``pre_lookahead_len`` tokens are
                passed to the encoder as ``context``.
            n_timesteps: Number of solver steps requested from the decoder
                (default 15, the previously hard-coded value).

        Returns:
            tuple: ``(feat, h_lengths)`` with ``feat`` as float32,
            shape [B, num_mels, T].
        """
        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # text encode
        if finalize is True:
            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        else:
            token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
            h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
        h = self.encoder_proj(h)

        # get conditions: prompt frames at the front, zeros elsewhere
        conds = torch.zeros_like(h, device=token.device)
        for i, j in enumerate(prompt_feat_len):
            conds[i, :j] = prompt_feat[i, :j]
        conds = conds.transpose(1, 2)

        h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
        mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
        feat, _ = self.decoder(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=n_timesteps,
            streaming=streaming
        )  # [B, num_mels, T]
        return feat.float(), h_lengths