ttsd-colabcli 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cli.js +148 -0
- package/core/app/__init__.py +0 -0
- package/core/app/colab_cli/__init__.py +0 -0
- package/core/app/colab_cli/__pycache__/__init__.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/auth.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/auto_update.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/cli.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/client.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/common.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/console.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/contents.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/history.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/runtime.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/state.cpython-312.pyc +0 -0
- package/core/app/colab_cli/__pycache__/utils.cpython-312.pyc +0 -0
- package/core/app/colab_cli/auth.py +278 -0
- package/core/app/colab_cli/auto_update.py +248 -0
- package/core/app/colab_cli/cli.py +155 -0
- package/core/app/colab_cli/client.py +310 -0
- package/core/app/colab_cli/commands/__init__.py +14 -0
- package/core/app/colab_cli/commands/__pycache__/__init__.cpython-312.pyc +0 -0
- package/core/app/colab_cli/commands/__pycache__/automation.cpython-312.pyc +0 -0
- package/core/app/colab_cli/commands/__pycache__/execution.cpython-312.pyc +0 -0
- package/core/app/colab_cli/commands/__pycache__/files.cpython-312.pyc +0 -0
- package/core/app/colab_cli/commands/__pycache__/run.cpython-312.pyc +0 -0
- package/core/app/colab_cli/commands/__pycache__/session.cpython-312.pyc +0 -0
- package/core/app/colab_cli/commands/__pycache__/utility.cpython-312.pyc +0 -0
- package/core/app/colab_cli/commands/automation.py +265 -0
- package/core/app/colab_cli/commands/execution.py +362 -0
- package/core/app/colab_cli/commands/files.py +204 -0
- package/core/app/colab_cli/commands/run.py +477 -0
- package/core/app/colab_cli/commands/session.py +519 -0
- package/core/app/colab_cli/commands/utility.py +436 -0
- package/core/app/colab_cli/common.py +185 -0
- package/core/app/colab_cli/console.py +172 -0
- package/core/app/colab_cli/contents.py +93 -0
- package/core/app/colab_cli/converter.py +184 -0
- package/core/app/colab_cli/history.py +65 -0
- package/core/app/colab_cli/oauth_config.json +11 -0
- package/core/app/colab_cli/repl.py +173 -0
- package/core/app/colab_cli/runtime.py +262 -0
- package/core/app/colab_cli/state.py +156 -0
- package/core/app/colab_cli/utils.py +85 -0
- package/core/colab/worker.py +679 -0
- package/core/daemon.py +184 -0
- package/core/requirements.txt +8 -0
- package/package.json +22 -0
|
@@ -0,0 +1,679 @@
|
|
|
1
|
+
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import argparse
|
|
5
|
+
import asyncio
|
|
6
|
+
import hashlib
|
|
7
|
+
import io
|
|
8
|
+
import json
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
import sys
|
|
12
|
+
import time
|
|
13
|
+
import gc
|
|
14
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
15
|
+
|
|
16
|
+
# Root logging is configured once at import time; worker messages go through
# the "worker" logger created below.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("worker")
|
|
18
|
+
from contextlib import nullcontext
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
import soundfile as sf
|
|
24
|
+
import torch
|
|
25
|
+
import websockets
|
|
26
|
+
|
|
27
|
+
# --- Static configuration -------------------------------------------------
SAMPLE_RATE = 24000  # sample rate used when this worker writes WAV data
REF_CACHE_DIR = Path("/tmp/omnivoice_refs")  # on-disk cache for reference clips
MODEL_ID = "k2-fsa/OmniVoice"
TASK_QUEUE_MAXSIZE = 16  # bound of the per-connection task queue

# --- Environment overrides ------------------------------------------------
# Fewer steps run faster; the original note says 32 gives the best quality.
OMNIVOICE_NUM_STEP = int(os.environ.get("OMNIVOICE_NUM_STEP", "16"))
# Higher guidance ties the output more tightly to the reference voice.
OMNIVOICE_GUIDANCE_SCALE = float(os.environ.get("OMNIVOICE_GUIDANCE_SCALE", "4.0"))

# Requested reference length in seconds, clamped to the range [1, 30].
_REF_MAX_RAW = float(os.environ.get("REF_AUDIO_MAX_SECONDS", "15"))
REF_AUDIO_MAX_SECONDS = max(1.0, min(30.0, _REF_MAX_RAW))
if _REF_MAX_RAW != REF_AUDIO_MAX_SECONDS:
    print(f"[ref] REF_AUDIO_MAX_SECONDS clamped to {REF_AUDIO_MAX_SECONDS:.1f}s (was {_REF_MAX_RAW})", flush=True)

OMNIVOICE_SPEED = float(os.environ.get("OMNIVOICE_SPEED", "1.0"))

# --- Shared runtime state -------------------------------------------------
# A single worker thread: model calls submitted to this pool run one at a time.
executor = ThreadPoolExecutor(max_workers=1)
# Insertion-ordered mapping used as a small LRU of encoded voice prompts.
_voice_prompt_cache: dict[str, Any] = {}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def configure_torch_runtime() -> None:
    """Apply best-effort PyTorch settings for inference.

    Gradient tracking is switched off globally. When CUDA is available the
    function additionally enables cuDNN benchmark mode, allows TF32 for CUDA
    matmul and cuDNN, and sets float32 matmul precision to "high". Each
    tweak is attempted independently and any exception it raises is ignored.
    """

    def _attempt(tweak) -> None:
        # Every knob is optional: a failing one must not stop the others.
        try:
            tweak()
        except Exception:
            pass

    def _allow_tf32() -> None:
        # Both flags share one attempt, so a failure on the first assignment
        # also skips the second.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    _attempt(lambda: torch.set_grad_enabled(False))

    if not torch.cuda.is_available():
        return

    _attempt(lambda: setattr(torch.backends.cudnn, "benchmark", True))
    _attempt(_allow_tf32)
    _attempt(lambda: torch.set_float32_matmul_precision("high"))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def autocast_context():
    """Return the context manager model calls should run under.

    With CUDA available this is ``torch.autocast`` for device type "cuda"
    with float16; otherwise it is a do-nothing ``nullcontext``.
    """
    if not torch.cuda.is_available():
        return nullcontext()
    return torch.autocast(device_type="cuda", dtype=torch.float16)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def normalize_server_url(server_url: str) -> str:
    """Return ``server_url`` stripped of whitespace and trailing slashes.

    A falsy value is treated as an empty string.

    Raises:
        ValueError: if the cleaned value does not start with ``http://`` or
            ``https://`` (the check is case-sensitive).
    """
    candidate = (server_url or "").strip()
    candidate = candidate.rstrip("/")
    has_http_scheme = candidate.startswith("http://") or candidate.startswith("https://")
    if has_http_scheme:
        return candidate
    raise ValueError(f"SERVER_URL không hợp lệ: {server_url!r}")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def websocket_url(server_url: str) -> str:
    """Derive the worker WebSocket endpoint from an HTTP(S) base URL.

    ``https://`` maps to ``wss://``; anything else is treated as plain HTTP,
    i.e. a leading ``http://`` is removed if present and ``ws://`` is
    prepended. ``/ws/worker`` is appended in both cases.
    """
    if server_url.startswith("https://"):
        scheme, remainder = "wss://", server_url.removeprefix("https://")
    else:
        scheme, remainder = "ws://", server_url.removeprefix("http://")
    return f"{scheme}{remainder}/ws/worker"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def detect_device() -> str:
    """Tune the torch runtime, report the device in use and return its name.

    Returns "cuda:0" when CUDA is available (after printing the GPU name and
    total memory in GiB), otherwise "cpu" (after printing a warning).
    """
    configure_torch_runtime()

    if not torch.cuda.is_available():
        print("⚠️ KHÔNG có GPU! Chạy trên CPU.", flush=True)
        return "cpu"

    name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"✅ GPU: {name} | VRAM={vram_gb:.1f}GB", flush=True)
    return "cuda:0"
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def cache_generate_signature(model: Any) -> set[str]:
    """Record the parameter names accepted by ``model.generate``.

    The names are stored on the model as ``_omnivoice_generate_params``,
    printed, and returned. If the signature cannot be inspected for any
    reason, an empty set is stored and returned instead.
    """
    import inspect

    try:
        names = {*inspect.signature(model.generate).parameters}
    except Exception:
        names = set()

    model._omnivoice_generate_params = names
    print(f"[model] generate params cached: {sorted(names)}", flush=True)
    return names
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def build_generate_kwargs(
    params: set[str],
    text: str,
    ref_audio: str,
    ref_text: str | None = None,
    language: str | None = None,
    num_step: int | None = None,
    guidance_scale: float | None = None,
) -> dict[str, Any]:
    """Map worker inputs onto whatever keyword names ``generate`` accepts.

    Args:
        params: parameter names supported by the model's ``generate``.
        text: text to synthesise.
        ref_audio: path of the reference audio.
        ref_text: optional transcript of the reference; skipped when falsy.
        language: optional language tag; skipped when falsy.
        num_step: step count; falls back to ``OMNIVOICE_NUM_STEP`` when None.
        guidance_scale: guidance strength; falls back to
            ``OMNIVOICE_GUIDANCE_SCALE`` when None.

    Returns:
        Keyword arguments for ``generate``. The text and reference audio are
        always included (under ``text`` / ``ref_audio`` if no known alias is
        supported); every other entry appears only when ``params`` contains
        a matching name.
    """

    def _first_supported(aliases: tuple[str, ...], fallback: str | None = None) -> str | None:
        # First alias the model understands, in priority order.
        for alias in aliases:
            if alias in params:
                return alias
        return fallback

    kwargs: dict[str, Any] = {}

    kwargs[_first_supported(("text", "prompt", "input_text"), "text")] = text

    audio_aliases = ("ref_audio", "reference_audio", "reference_wav", "prompt_audio")
    kwargs[_first_supported(audio_aliases, "ref_audio")] = ref_audio

    if ref_text:
        ref_text_key = _first_supported(("ref_text", "reference_text", "prompt_text"))
        if ref_text_key is not None:
            kwargs[ref_text_key] = ref_text

    if language:
        language_key = _first_supported(("language", "lang"))
        if language_key is not None:
            kwargs[language_key] = language

    steps = OMNIVOICE_NUM_STEP if num_step is None else num_step
    steps_key = _first_supported(("num_step", "num_steps"))
    if steps_key is not None:
        kwargs[steps_key] = steps

    # The worker trims the reference itself (prepare_ref_audio), so any
    # model-side preprocessing switch is turned off when one exists.
    preprocess_key = _first_supported(("preprocess_prompt", "preprocess"))
    if preprocess_key is not None:
        kwargs[preprocess_key] = False

    scale = OMNIVOICE_GUIDANCE_SCALE if guidance_scale is None else guidance_scale
    if "guidance_scale" in params:
        kwargs["guidance_scale"] = scale

    if "speed" in params:
        kwargs["speed"] = OMNIVOICE_SPEED

    if "use_cache" in params:
        kwargs["use_cache"] = True

    return kwargs
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def load_model(device: str) -> Any:
    """Load the OmniVoice model, cache its call signatures and warm it up.

    Args:
        device: value passed to ``from_pretrained`` as ``device_map``.

    Returns:
        The loaded model, carrying ``_omnivoice_generate_params`` and
        ``_omnivoice_prompt_params`` (the latter is an empty set when
        ``create_voice_clone_prompt`` cannot be inspected).

    A failing warmup is reported and otherwise ignored.
    """
    import inspect

    print("🔄 Đang tải model OmniVoice...", flush=True)
    load_started = time.time()

    from omnivoice import OmniVoice

    # Half precision whenever CUDA is available, full precision otherwise.
    weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = OmniVoice.from_pretrained(MODEL_ID, device_map=device, dtype=weights_dtype)

    cache_generate_signature(model)
    try:
        prompt_signature = inspect.signature(model.create_voice_clone_prompt)
        model._omnivoice_prompt_params = set(prompt_signature.parameters)
        print(f"[model] voice prompt params cached: {sorted(model._omnivoice_prompt_params)}", flush=True)
    except Exception:
        model._omnivoice_prompt_params = set()
    print(f"✅ Model loaded trong {time.time() - load_started:.1f}s", flush=True)

    print("🔥 Đang warmup model...", flush=True)
    try:
        warmup_wav = "/tmp/warmup.wav"
        if not os.path.exists(warmup_wav):
            import numpy as np

            # One second of silence is enough to act as a reference clip.
            silence = np.zeros(SAMPLE_RATE, dtype="float32")
            sf.write(warmup_wav, silence, SAMPLE_RATE, format="WAV")

        warmup_kwargs = build_generate_kwargs(model._omnivoice_generate_params, "Warmup", warmup_wav)
        with torch.inference_mode(), autocast_context():
            model.generate(**warmup_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print("✅ Warmup thành công.", flush=True)
    except Exception as exc:
        print(f"⚠️ Warmup lỗi (bỏ qua): {exc}", flush=True)

    return model
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _audio_to_wav_bytes(audio: Any) -> bytes:
    """Convert a ``model.generate`` result into WAV file bytes.

    Handled inputs:
      * ``bytes`` - returned unchanged.
      * ``io.BytesIO`` - rewound and read in full.
      * ``dict`` - the value under the first present key among "audio",
        "wav", "output", "samples", "waveform"; otherwise the first value.
      * ``list`` / ``tuple`` - the first element.
      * ``torch.Tensor`` / ``numpy.ndarray`` - squeezed, cast to float32 and
        written as WAV at ``SAMPLE_RATE``.

    Raises:
        ValueError: for an empty list or tuple.
        TypeError: for anything else, including an empty dict.
    """
    import numpy as np

    if isinstance(audio, bytes):
        return audio

    if isinstance(audio, io.BytesIO):
        audio.seek(0)
        return audio.read()

    if isinstance(audio, dict):
        preferred = ("audio", "wav", "output", "samples", "waveform")
        hit = next((name for name in preferred if name in audio), None)
        if hit is not None:
            return _audio_to_wav_bytes(audio[hit])
        if audio:
            first_key = next(iter(audio))
            return _audio_to_wav_bytes(audio[first_key])
        # An empty dict falls through to the TypeError at the bottom.

    if isinstance(audio, (list, tuple)):
        if len(audio) == 0:
            raise ValueError("model.generate returned empty audio list")
        return _audio_to_wav_bytes(audio[0])

    if isinstance(audio, torch.Tensor):
        audio = audio.detach().float().cpu().numpy()

    if not isinstance(audio, np.ndarray):
        raise TypeError(f"Unsupported model.generate output type: {type(audio)!r}")

    samples = audio.squeeze().astype("float32", copy=False)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, samples, SAMPLE_RATE, format="WAV")
    return wav_buffer.getvalue()
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def prepare_ref_audio(ref_audio: str) -> str:
    """Return a cleaned copy of a reference clip, cached on disk.

    The clip is down-mixed to mono when it has more than one channel,
    stripped of leading/trailing samples below -40 dB (keeping 100 ms of
    padding on each side) and cut to ``REF_AUDIO_MAX_SECONDS``. The result
    is stored in ``REF_CACHE_DIR`` under a key derived from the source path,
    its modification time and the length limit, so repeated calls for the
    same file reuse the cached copy.

    Args:
        ref_audio: path of the source audio file.

    Returns:
        Path of the cleaned WAV, or ``ref_audio`` unchanged when trimming is
        disabled, the source does not exist, or processing fails.
    """
    if REF_AUDIO_MAX_SECONDS <= 0:
        return ref_audio

    src = Path(ref_audio)
    if not src.exists():
        return ref_audio

    REF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    trim_key = hashlib.md5(
        f"{src}:{src.stat().st_mtime_ns}:{REF_AUDIO_MAX_SECONDS}".encode(),
        usedforsecurity=False,
    ).hexdigest()
    trimmed = REF_CACHE_DIR / f"{trim_key}.clean.wav"
    # 44 bytes is the size of a bare WAV header: anything larger has samples.
    if trimmed.exists() and trimmed.stat().st_size > 44:
        return str(trimmed)

    # Write to a private temp file and rename it into place. Writing straight
    # to `trimmed` could leave a truncated file behind after a crash, and the
    # size check above would then accept it as a valid cache entry.
    tmp_path = REF_CACHE_DIR / f"{trim_key}.{os.getpid()}.clean.tmp"
    try:
        import numpy as np

        audio, sr = sf.read(str(src), always_2d=False)

        # Down-mix multi-channel audio to mono.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        # 1. Drop leading/trailing samples quieter than -40 dB (|x| <= 0.01).
        threshold = 10 ** (-40.0 / 20.0)
        loud = np.where(np.abs(audio) > threshold)[0]
        if len(loud) > 0:
            # Keep 100 ms around the detected region so word onsets and
            # endings are not clipped.
            pad_samples = int(sr * 0.1)
            start_idx = max(0, loud[0] - pad_samples)
            end_idx = min(len(audio), loud[-1] + 1 + pad_samples)
            audio = audio[start_idx:end_idx]

        # 2. Enforce the maximum duration, keeping the beginning of the clip.
        max_samples = int(sr * REF_AUDIO_MAX_SECONDS)
        if len(audio) > max_samples:
            audio = audio[:max_samples]

        sf.write(str(tmp_path), audio, sr, format="WAV")
        tmp_path.replace(trimmed)
        print(f"[ref] cleaned & trimmed {src.name} to {len(audio)/sr:.1f}s", flush=True)
        return str(trimmed)
    except Exception as exc:
        # Best effort: fall back to the untouched reference.
        print(f"[ref] prep skipped: {exc}", flush=True)
        return ref_audio
    finally:
        # Remove a leftover temp file after a failure; after a successful
        # rename there is nothing left to delete.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def build_voice_prompt_kwargs(params: set[str], ref_audio: str, ref_text: str | None = None) -> dict[str, Any]:
    """Build keyword arguments for ``create_voice_clone_prompt``.

    The reference audio is always passed, under the first supported alias or
    ``ref_audio`` when none matches. The reference text is passed only when
    it is truthy and ``params`` contains one of its aliases.
    """

    def _first_supported(aliases: tuple[str, ...]) -> str | None:
        return next((alias for alias in aliases if alias in params), None)

    audio_key = _first_supported(("ref_audio", "reference_audio", "reference_wav", "prompt_audio"))
    kwargs: dict[str, Any] = {audio_key or "ref_audio": ref_audio}

    if ref_text:
        text_key = _first_supported(("ref_text", "reference_text", "prompt_text"))
        if text_key is not None:
            kwargs[text_key] = ref_text

    return kwargs
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def get_voice_clone_prompt(model: Any, ref_audio: str, ref_text: str | None = None) -> Any | None:
    """Return the encoded voice prompt for a reference clip, with LRU caching.

    Returns None when the model has no ``create_voice_clone_prompt``, when
    its parameter names were not cached on the model, or when building the
    prompt raises. Cached prompts are keyed by reference path, file
    modification time and reference text; the cache holds at most 15
    entries and drops the least recently used one.
    """
    if not hasattr(model, "create_voice_clone_prompt"):
        return None

    params = getattr(model, "_omnivoice_prompt_params", set())
    if not params:
        return None

    src = Path(ref_audio)
    try:
        mtime = src.stat().st_mtime_ns
    except Exception:
        mtime = 0  # unreadable file: the key then depends on path and text only
    key_material = f"{ref_audio}:{mtime}:{ref_text or ''}"
    cache_key = hashlib.md5(key_material.encode(), usedforsecurity=False).hexdigest()

    if cache_key in _voice_prompt_cache:
        # Re-insert so this entry becomes the most recently used one.
        cached = _voice_prompt_cache.pop(cache_key)
        _voice_prompt_cache[cache_key] = cached
        return cached

    try:
        prompt_kwargs = build_voice_prompt_kwargs(params, ref_audio, ref_text)
        build_started = time.time()
        with torch.inference_mode(), autocast_context():
            prompt = model.create_voice_clone_prompt(**prompt_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if len(_voice_prompt_cache) >= 15:
            # Dicts iterate in insertion order, so the first key is the
            # least recently used entry.
            stale_key = next(iter(_voice_prompt_cache))
            _voice_prompt_cache.pop(stale_key, None)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        _voice_prompt_cache[cache_key] = prompt
        print(f"[voice-cache] prompt built in {(time.time() - build_started) * 1000:.0f}ms", flush=True)
        return prompt
    except Exception as exc:
        print(f"[voice-cache] prompt build skipped: {exc}", flush=True)
        return None
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _postprocess_generated_wav(audio_bytes: bytes) -> bytes:
    """Fade out the last 100 ms of a WAV and append 0.5 s of silence.

    Works for both mono ``(frames,)`` and multi-channel ``(frames, channels)``
    data. The output keeps the sample rate of the input.
    """
    import numpy as np

    samples, sr = sf.read(io.BytesIO(audio_bytes), always_2d=False)
    # Trailing singleton axes let a per-frame gain broadcast over channels.
    channel_axes = (1,) * (samples.ndim - 1)

    # Raised-cosine fade-out so the clip does not end on an audible click.
    fade_len = int(sr * 0.1)
    if 0 < fade_len < len(samples):
        curve = 0.5 + 0.5 * np.cos(np.pi * np.linspace(0, 1, fade_len))
        samples[-fade_len:] = samples[-fade_len:] * curve.reshape((fade_len, *channel_axes))

    # Trailing silence guards against players cutting off the last word.
    silence = np.zeros((int(sr * 0.5), *samples.shape[1:]), dtype=samples.dtype)
    padded = np.concatenate([samples, silence])

    out = io.BytesIO()
    sf.write(out, padded, sr, format="WAV")
    return out.getvalue()


def run_tts(model: Any, text: str, ref_audio: str, ref_text: str | None = None, language: str | None = None,
            num_step: int | None = None, guidance_scale: float | None = None) -> bytes:
    """Synthesise ``text`` in the voice of ``ref_audio`` and return WAV bytes.

    Args:
        model: loaded model exposing ``generate``.
        text: text to speak.
        ref_audio: path of the reference clip; it is cleaned first.
        ref_text: optional transcript of the reference clip.
        language: optional language tag.
        num_step: optional override of the step count.
        guidance_scale: optional override of the guidance scale.

    Returns:
        WAV bytes of the generated speech with a 100 ms fade-out and 0.5 s
        of trailing silence. If post-processing fails, the unprocessed WAV
        is returned instead.
    """
    # Trim silence and excess length from the reference before use.
    ref_audio = prepare_ref_audio(ref_audio)

    params = getattr(model, "_omnivoice_generate_params", set())
    kwargs = build_generate_kwargs(params, text, ref_audio, ref_text=ref_text, language=language,
                                   num_step=num_step, guidance_scale=guidance_scale)

    voice_prompt = get_voice_clone_prompt(model, ref_audio, ref_text)
    if voice_prompt is not None and "voice_clone_prompt" in params:
        # The reference-audio kwarg is deliberately left in place: per the
        # original author's note, some API versions need it as context even
        # when a pre-encoded prompt is supplied.
        kwargs["voice_clone_prompt"] = voice_prompt
        logger.info("[tts] using cached voice prompt")

    if torch.cuda.is_available():
        # Start a fresh peak-memory measurement for this request.
        torch.cuda.reset_peak_memory_stats()

    print(f"🎤 Generating TTS: {len(text)} chars...", flush=True)
    with torch.inference_mode(), autocast_context():
        output = model.generate(**kwargs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    audio_bytes = _audio_to_wav_bytes(output)

    try:
        audio_bytes = _postprocess_generated_wav(audio_bytes)
    except Exception as exc:
        # Post-processing is cosmetic; keep the raw audio if it fails.
        print(f"[post-process] failed to clean audio: {exc}", flush=True)

    return audio_bytes
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
async def download_ref_audio(client: httpx.AsyncClient, server_url: str, voice_url: str, max_seconds: float = REF_AUDIO_MAX_SECONDS) -> str:
    """Fetch a reference clip into the local cache and return its path.

    Args:
        client: HTTP client used for the download.
        server_url: base URL prepended when ``voice_url`` starts with "/".
        voice_url: absolute URL, or a server-relative path.
        max_seconds: when positive, longer clips are cut to this duration.

    Returns:
        Path of the cached file. An existing cache entry larger than 44
        bytes is reused without any network access.

    Raises:
        Whatever the HTTP request or ``raise_for_status`` raises; a failure
        while shortening the clip is only reported and the downloaded bytes
        are stored unchanged.
    """
    REF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    digest = hashlib.md5(f"{voice_url}:{max_seconds}".encode(), usedforsecurity=False).hexdigest()
    local_path = REF_CACHE_DIR / f"{digest}.wav"

    if local_path.exists() and local_path.stat().st_size > 44:
        return str(local_path)

    if voice_url.startswith("/"):
        full_url = f"{server_url}{voice_url}"
    else:
        full_url = voice_url
    tmp_path = local_path.with_suffix(".tmp")

    resp = await client.get(full_url, timeout=30)
    resp.raise_for_status()
    payload = resp.content

    if max_seconds > 0:
        try:
            audio, sr = sf.read(io.BytesIO(payload), always_2d=False)
            limit = int(sr * max_seconds)
            if len(audio) > limit:
                clipped = io.BytesIO()
                sf.write(clipped, audio[:limit], sr, format="WAV")
                payload = clipped.getvalue()
                print(f"[ref] trimmed download to {max_seconds:.1f}s", flush=True)
        except Exception as exc:
            print(f"[ref] download trim skipped: {exc}", flush=True)

    # Write next to the target and rename, so a cache entry is never partial.
    tmp_path.write_bytes(payload)
    tmp_path.replace(local_path)
    return str(local_path)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
async def send_json_safe(ws: Any, payload: dict[str, Any]) -> None:
    """Send ``payload`` as JSON over ``ws`` without ever raising.

    Serialisation and transport errors are treated as non-fatal (the
    connection loop notices a dead socket on its own), but they are logged
    so that dropped messages remain visible.

    Args:
        ws: object with an awaitable ``send(str)`` method.
        payload: JSON-serialisable message.
    """
    try:
        await ws.send(json.dumps(payload))
    except Exception as exc:
        action = payload.get("action") if isinstance(payload, dict) else None
        # Same logger as the module-level "worker" logger.
        logging.getLogger("worker").warning("[ws] could not send %r message: %s", action, exc)
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
async def send_status(ws: Any, status: str, queue_size: int | None = None, worker_session_id: str = "") -> None:
    """Report the worker status to the server on a best-effort basis.

    The message has action "status" plus the given status and session id;
    ``queue_size`` is included only when it is not None.
    """
    base: dict[str, Any] = {"action": "status", "status": status, "worker_session_id": worker_session_id}
    extra: dict[str, Any] = {} if queue_size is None else {"queue_size": queue_size}
    await send_json_safe(ws, {**base, **extra})
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
async def process_task(model: Any, ws: Any, http_client: httpx.AsyncClient, server_url: str, data: dict[str, Any], worker_session_id: str = "") -> None:
    """Run one TTS task end to end and report the outcome to the server.

    Steps: download the reference clip, synthesise on the executor thread,
    upload the WAV to ``/api/tasks/<task_id>/complete``, then send
    ``task_completed`` over the WebSocket. Any failure, including a task
    message that lacks required fields, is reported as ``task_failed``
    instead of being raised.

    Args:
        model: loaded TTS model.
        ws: WebSocket used for status messages.
        http_client: client for the download and the upload.
        server_url: normalised server base URL.
        data: task message; requires "task_id", "text" and "voice_api_url",
            optionally "voice_ref_text", "language", "num_step" and
            "guidance_scale".
        worker_session_id: echoed back in every message and in the upload.
    """
    # Read leniently here so the failure report below can always name the task.
    task_id = data.get("task_id")

    try:
        # Validation happens inside the try so that a malformed message is
        # answered with task_failed rather than escaping to the caller.
        missing = [name for name in ("task_id", "text", "voice_api_url") if name not in data]
        if missing:
            raise ValueError(f"task message is missing required field(s): {', '.join(missing)}")

        text = data["text"]
        voice_url = data["voice_api_url"]
        ref_text = (data.get("voice_ref_text") or "").strip() or None
        language = data.get("language")

        short_text = text[:70] + ("..." if len(text) > 70 else "")
        print(f"[task] {task_id} | {short_text}", flush=True)

        ref_started = time.time()
        ref_path = await download_ref_audio(http_client, server_url, voice_url)
        ref_ms = (time.time() - ref_started) * 1000

        # Inference is blocking, so it runs on the executor thread.
        loop = asyncio.get_running_loop()
        tts_started = time.time()
        ns = data.get("num_step")
        gs = data.get("guidance_scale")
        result_audio = await loop.run_in_executor(
            executor, run_tts, model, text, ref_path, ref_text, language, ns, gs
        )
        tts_ms = (time.time() - tts_started) * 1000

        upload_started = time.time()
        upload_url = f"{server_url}/api/tasks/{task_id}/complete"
        upload_response = await http_client.post(
            upload_url,
            data={"worker_session_id": worker_session_id},
            files={"audio": ("result.wav", result_audio, "audio/wav")},
            timeout=120,
        )
        upload_response.raise_for_status()
        upload_ms = (time.time() - upload_started) * 1000

        await send_json_safe(ws, {"action": "task_completed", "task_id": task_id, "worker_session_id": worker_session_id})

        # Rough duration for the log line only: treats the payload as
        # 2 bytes per sample at SAMPLE_RATE and ignores the WAV header.
        audio_seconds = len(result_audio) / (SAMPLE_RATE * 2)
        peak_mb = 0.0
        if torch.cuda.is_available():
            peak_mb = torch.cuda.max_memory_allocated() / 1024**2
        print(
            f"[ok] {task_id} ref={ref_ms:.0f}ms tts={tts_ms:.0f}ms upload={upload_ms:.0f}ms "
            f"audio~{audio_seconds:.1f}s peak={peak_mb:.0f}MB",
            flush=True,
        )
    except Exception as exc:
        print(f"[fail] {task_id}: {exc}", flush=True)
        await send_json_safe(ws, {"action": "task_failed", "task_id": task_id, "error": str(exc), "worker_session_id": worker_session_id})
    finally:
        # Release per-task garbage and cached GPU blocks before the next task.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
async def task_consumer(
    queue: asyncio.Queue[dict[str, Any]],
    model: Any,
    ws: Any,
    http_client: httpx.AsyncClient,
    server_url: str,
    worker_session_id: str = "",
) -> None:
    """Process queued tasks one at a time until the task is cancelled.

    Before each task the server is told the worker is BUSY; afterwards it is
    told IDLE or BUSY depending on whether more tasks are waiting. An
    unexpected error while handling one task is reported as ``task_failed``
    and does not stop the consumer, so later tasks are still served.
    Cancellation is not intercepted.
    """
    while True:
        data = await queue.get()
        try:
            await send_status(ws, "BUSY", queue.qsize(), worker_session_id)
            await process_task(model, ws, http_client, server_url, data, worker_session_id)
        except Exception as exc:
            # Without this, one bad task would end the consumer for the rest
            # of the connection and the stored exception would resurface when
            # the connection loop later awaits this task.
            print(f"[consumer] task crashed: {exc}", flush=True)
            await send_json_safe(ws, {
                "action": "task_failed",
                "task_id": data.get("task_id") if isinstance(data, dict) else None,
                "error": str(exc),
                "worker_session_id": worker_session_id,
            })
        finally:
            queue.task_done()
        await send_status(ws, "IDLE" if queue.empty() else "BUSY", queue.qsize(), worker_session_id)
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
async def _stop_consumer(consumer_task: asyncio.Task | None) -> None:
    """Cancel the per-connection consumer task and wait for it to finish."""
    if consumer_task is None:
        return
    consumer_task.cancel()
    try:
        await consumer_task
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        # The consumer had already died; report it instead of letting the
        # stored exception escape from the reconnect path.
        print(f"[ws] consumer ended with error: {exc}", flush=True)


async def worker_loop(model: Any, server_url: str, email: str, worker_session_id: str = "") -> None:
    """Maintain the WebSocket session with the server and dispatch tasks.

    Each connection registers the worker, starts a consumer task with a
    fresh bounded queue, and then handles "run_tts", "ping" and "shutdown"
    messages. The connection is dropped when no ping arrives for more than
    30 seconds. After any disconnect the consumer is stopped, memory caches
    are released and the loop reconnects after 5 seconds.

    Returns when the server sends "shutdown". Calls ``sys.exit(1)`` when the
    coroutine is cancelled, when the server closes with code 4002, or when
    reconnecting has been failing for more than 180 seconds.

    Args:
        model: loaded TTS model handed to the consumer.
        server_url: normalised HTTP(S) base URL of the server.
        email: identity sent in the register message.
        worker_session_id: session id echoed in every message.
    """
    ws_url = websocket_url(server_url)
    limits = httpx.Limits(max_connections=8, max_keepalive_connections=4)
    # Time of the first failure in the current run of failed connections.
    reconnect_start: float | None = None

    async with httpx.AsyncClient(limits=limits, http2=True, timeout=60) as http_client:
        while True:
            consumer_task: asyncio.Task | None = None
            # A new queue per connection: tasks queued on a lost connection
            # are not carried over.
            task_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=TASK_QUEUE_MAXSIZE)
            last_ping_received = time.time()
            try:
                print(f"[ws] Connecting: {ws_url}", flush=True)
                async with websockets.connect(
                    ws_url,
                    open_timeout=30,
                    close_timeout=5,
                    ping_interval=None,  # liveness uses application-level pings
                ) as ws:
                    gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
                    await ws.send(json.dumps({
                        "action": "register",
                        "email": email,
                        "worker_session_id": worker_session_id,
                        "gpu": gpu
                    }))
                    consumer_task = asyncio.create_task(
                        task_consumer(task_queue, model, ws, http_client, server_url, worker_session_id)
                    )
                    print(f"🚀 Connected: {email} (GPU: {gpu})", flush=True)
                    await send_status(ws, "IDLE", 0, worker_session_id)
                    reconnect_start = None

                    while True:
                        try:
                            raw = await asyncio.wait_for(ws.recv(), timeout=30.0)
                        except asyncio.TimeoutError:
                            if time.time() - last_ping_received > 30.0:
                                print("[ws] Heartbeat timeout - no ping for >30s", flush=True)
                                break
                            continue

                        data = json.loads(raw)
                        action = data.get("action")

                        if action == "run_tts":
                            try:
                                task_queue.put_nowait(data)
                                await send_status(ws, "BUSY", task_queue.qsize(), worker_session_id)
                            except asyncio.QueueFull:
                                await send_json_safe(ws, {
                                    "action": "task_failed",
                                    "task_id": data.get("task_id"),
                                    "error": "Worker queue full",
                                    "worker_session_id": worker_session_id,
                                })
                        elif action == "ping":
                            last_ping_received = time.time()
                            current_status = "IDLE" if task_queue.empty() else "BUSY"
                            await ws.send(json.dumps({
                                "action": "pong_status",
                                "status": current_status,
                                "worker_session_id": worker_session_id,
                            }))
                        elif action == "shutdown":
                            print("[ws] Server yêu cầu shutdown.", flush=True)
                            # Wait for the consumer to unwind so no task is
                            # left pending when the event loop closes.
                            await _stop_consumer(consumer_task)
                            return
            except asyncio.CancelledError:
                print("[ws] Worker cancelled. Exiting.", flush=True)
                sys.exit(1)
            except Exception as exc:
                # Close code 4002 means the server rejected this session.
                if isinstance(exc, websockets.exceptions.ConnectionClosed):
                    if exc.code == 4002:
                        print(f"[ws] Connection rejected by server (code 4002: session lost/expired). Exiting.", flush=True)
                        sys.exit(1)

                print(f"🔄 Reconnecting... ({exc})", flush=True)

            # Cleanup and back-off run after every disconnect, whether it
            # came from an exception or from the heartbeat-timeout break.
            # Otherwise the old consumer would keep running against a closed
            # socket and the loop would reconnect without any delay.
            await _stop_consumer(consumer_task)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if reconnect_start is None:
                reconnect_start = time.time()
            elif time.time() - reconnect_start > 180:
                print("[ws] Reconnection timeout >180s. Exiting.", flush=True)
                sys.exit(1)
            await asyncio.sleep(5)
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
def parse_args() -> argparse.Namespace:
    """Parse the worker's command-line options from ``sys.argv``.

    ``--server-url`` and ``--email`` are required; ``--worker-session-id``
    defaults to an empty string.
    """
    parser = argparse.ArgumentParser(description="Colab OmniVoice TTS Worker")
    option_table = (
        ("--server-url", {"required": True}),
        ("--email", {"required": True}),
        ("--worker-session-id", {"default": ""}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def main() -> None:
    """Entry point: validate the arguments, load the model, run the worker.

    Exits with status 1 when the server URL is rejected by
    ``normalize_server_url`` or when the e-mail is empty after stripping.
    """

    def _abort(message: str) -> None:
        # Print the problem and terminate with a failure status.
        print(message, flush=True)
        sys.exit(1)

    args = parse_args()

    try:
        server_url = normalize_server_url(args.server_url)
    except ValueError as exc:
        _abort(f"[error] {exc}")

    email = args.email.strip()
    if email == "":
        _abort("[error] EMAIL không được để trống.")

    worker_session_id = args.worker_session_id.strip()

    print(f"[fast-mode] num_step={OMNIVOICE_NUM_STEP} guidance_scale={OMNIVOICE_GUIDANCE_SCALE} ref_max={REF_AUDIO_MAX_SECONDS}s speed={OMNIVOICE_SPEED}", flush=True)
    device = detect_device()
    model = load_model(device)
    asyncio.run(worker_loop(model, server_url, email, worker_session_id))


# Start the worker only when this file is executed as a script.
if __name__ == "__main__":
    main()
|