ttsd-colabcli 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. package/cli.js +148 -0
  2. package/core/app/__init__.py +0 -0
  3. package/core/app/colab_cli/__init__.py +0 -0
  4. package/core/app/colab_cli/__pycache__/__init__.cpython-312.pyc +0 -0
  5. package/core/app/colab_cli/__pycache__/auth.cpython-312.pyc +0 -0
  6. package/core/app/colab_cli/__pycache__/auto_update.cpython-312.pyc +0 -0
  7. package/core/app/colab_cli/__pycache__/cli.cpython-312.pyc +0 -0
  8. package/core/app/colab_cli/__pycache__/client.cpython-312.pyc +0 -0
  9. package/core/app/colab_cli/__pycache__/common.cpython-312.pyc +0 -0
  10. package/core/app/colab_cli/__pycache__/console.cpython-312.pyc +0 -0
  11. package/core/app/colab_cli/__pycache__/contents.cpython-312.pyc +0 -0
  12. package/core/app/colab_cli/__pycache__/history.cpython-312.pyc +0 -0
  13. package/core/app/colab_cli/__pycache__/runtime.cpython-312.pyc +0 -0
  14. package/core/app/colab_cli/__pycache__/state.cpython-312.pyc +0 -0
  15. package/core/app/colab_cli/__pycache__/utils.cpython-312.pyc +0 -0
  16. package/core/app/colab_cli/auth.py +278 -0
  17. package/core/app/colab_cli/auto_update.py +248 -0
  18. package/core/app/colab_cli/cli.py +155 -0
  19. package/core/app/colab_cli/client.py +310 -0
  20. package/core/app/colab_cli/commands/__init__.py +14 -0
  21. package/core/app/colab_cli/commands/__pycache__/__init__.cpython-312.pyc +0 -0
  22. package/core/app/colab_cli/commands/__pycache__/automation.cpython-312.pyc +0 -0
  23. package/core/app/colab_cli/commands/__pycache__/execution.cpython-312.pyc +0 -0
  24. package/core/app/colab_cli/commands/__pycache__/files.cpython-312.pyc +0 -0
  25. package/core/app/colab_cli/commands/__pycache__/run.cpython-312.pyc +0 -0
  26. package/core/app/colab_cli/commands/__pycache__/session.cpython-312.pyc +0 -0
  27. package/core/app/colab_cli/commands/__pycache__/utility.cpython-312.pyc +0 -0
  28. package/core/app/colab_cli/commands/automation.py +265 -0
  29. package/core/app/colab_cli/commands/execution.py +362 -0
  30. package/core/app/colab_cli/commands/files.py +204 -0
  31. package/core/app/colab_cli/commands/run.py +477 -0
  32. package/core/app/colab_cli/commands/session.py +519 -0
  33. package/core/app/colab_cli/commands/utility.py +436 -0
  34. package/core/app/colab_cli/common.py +185 -0
  35. package/core/app/colab_cli/console.py +172 -0
  36. package/core/app/colab_cli/contents.py +93 -0
  37. package/core/app/colab_cli/converter.py +184 -0
  38. package/core/app/colab_cli/history.py +65 -0
  39. package/core/app/colab_cli/oauth_config.json +11 -0
  40. package/core/app/colab_cli/repl.py +173 -0
  41. package/core/app/colab_cli/runtime.py +262 -0
  42. package/core/app/colab_cli/state.py +156 -0
  43. package/core/app/colab_cli/utils.py +85 -0
  44. package/core/colab/worker.py +679 -0
  45. package/core/daemon.py +184 -0
  46. package/core/requirements.txt +8 -0
  47. package/package.json +22 -0
@@ -0,0 +1,679 @@
1
+
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import asyncio
6
+ import hashlib
7
+ import io
8
+ import json
9
+ import logging
10
+ import os
11
+ import sys
12
+ import time
13
+ import gc
14
+ from concurrent.futures import ThreadPoolExecutor
15
+
16
# Root logging configuration for the worker process. The named "worker" logger
# is used for per-task diagnostics (e.g. reuse of a cached voice prompt).
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("worker")
18
+ from contextlib import nullcontext
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ import httpx
23
+ import soundfile as sf
24
+ import torch
25
+ import websockets
26
+
27
def _env_number(name: str, default: str, cast: Any) -> Any:
    """Read environment variable *name* (or *default* when unset) and convert it with *cast*."""
    return cast(os.getenv(name, default))


# Fixed audio / model settings.
SAMPLE_RATE = 24000
REF_CACHE_DIR = Path("/tmp/omnivoice_refs")
MODEL_ID = "k2-fsa/OmniVoice"
TASK_QUEUE_MAXSIZE = 16

# Environment-tunable inference settings.
# Diffusion steps: fewer steps run faster, 32 gives the best quality.
OMNIVOICE_NUM_STEP = _env_number("OMNIVOICE_NUM_STEP", "16", int)
# Voice guidance strength: higher values follow the cloned voice more tightly.
OMNIVOICE_GUIDANCE_SCALE = _env_number("OMNIVOICE_GUIDANCE_SCALE", "4.0", float)
# Reference clip length cap (default 15s for clone quality), clamped to [1, 30] seconds.
_REF_MAX_RAW = _env_number("REF_AUDIO_MAX_SECONDS", "15", float)
REF_AUDIO_MAX_SECONDS = max(1.0, min(30.0, _REF_MAX_RAW))
if REF_AUDIO_MAX_SECONDS != _REF_MAX_RAW:
    print(f"[ref] REF_AUDIO_MAX_SECONDS clamped to {REF_AUDIO_MAX_SECONDS:.1f}s (was {_REF_MAX_RAW})", flush=True)
OMNIVOICE_SPEED = _env_number("OMNIVOICE_SPEED", "1.0", float)

# A single worker thread: blocking TTS calls are dispatched here one at a time.
executor = ThreadPoolExecutor(max_workers=1)
# Memoised voice-clone prompts; dict insertion order is used as LRU order.
_voice_prompt_cache: dict[str, Any] = {}
41
+
42
+
43
def configure_torch_runtime() -> None:
    """Apply process-wide PyTorch inference defaults (tuned for Colab T4).

    Autograd is always switched off. When CUDA is available, cuDNN autotuning,
    TF32 matmul/convolutions and "high" float32 matmul precision are also
    enabled. Every knob is applied independently and best-effort: a setting
    the installed torch build rejects is skipped instead of aborting startup.
    The model API itself is not touched.
    """

    def _attempt(setting) -> None:
        # Best-effort by design: unsupported knobs must not stop the worker.
        try:
            setting()
        except Exception:
            pass

    def _disable_autograd() -> None:
        torch.set_grad_enabled(False)

    def _enable_cudnn_autotune() -> None:
        torch.backends.cudnn.benchmark = True

    def _enable_tf32() -> None:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def _relax_fp32_matmul() -> None:
        torch.set_float32_matmul_precision("high")

    _attempt(_disable_autograd)
    if not torch.cuda.is_available():
        return
    for setting in (_enable_cudnn_autotune, _enable_tf32, _relax_fp32_matmul):
        _attempt(setting)
68
+
69
+
70
def autocast_context():
    """Return the precision context wrapped around every model call.

    On CUDA this is ``torch.autocast`` with float16. On CPU-only hosts it is a
    no-op ``nullcontext``, so callers can always write ``with autocast_context():``.
    """
    if not torch.cuda.is_available():
        return nullcontext()
    return torch.autocast(device_type="cuda", dtype=torch.float16)
74
+
75
+
76
def normalize_server_url(server_url: str) -> str:
    """Validate and canonicalise the server base URL.

    Surrounding whitespace and trailing slashes are removed, and ``None`` or
    empty input is treated as ``""``.

    Raises:
        ValueError: if the cleaned value does not start with ``http://`` or ``https://``.
    """
    candidate = (server_url or "").strip().rstrip("/")
    if candidate.startswith(("http://", "https://")):
        return candidate
    raise ValueError(f"SERVER_URL không hợp lệ: {server_url!r}")
81
+
82
+
83
def websocket_url(server_url: str) -> str:
    """Derive the worker websocket endpoint from an http(s) base URL.

    ``https://`` becomes ``wss://``. Anything else is treated as plain
    ``http://`` and becomes ``ws://``. ``/ws/worker`` is appended.
    """
    if server_url.startswith("https://"):
        scheme, http_prefix = "wss://", "https://"
    else:
        scheme, http_prefix = "ws://", "http://"
    return f"{scheme}{server_url.removeprefix(http_prefix)}/ws/worker"
87
+
88
+
89
def detect_device() -> str:
    """Apply torch runtime defaults, report the accelerator, and pick a device.

    Returns:
        ``"cuda:0"`` when CUDA is available (after printing GPU name and VRAM),
        otherwise ``"cpu"`` (after printing a warning).
    """
    configure_torch_runtime()
    if not torch.cuda.is_available():
        print("⚠️ KHÔNG có GPU! Chạy trên CPU.", flush=True)
        return "cpu"

    gpu_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"✅ GPU: {gpu_name} | VRAM={vram_gb:.1f}GB", flush=True)
    return "cuda:0"
98
+
99
+
100
def cache_generate_signature(model: Any) -> set[str]:
    """Introspect ``model.generate`` once and remember its parameter names.

    The names are stored on the model as ``_omnivoice_generate_params``, which
    run_tts and the warmup read later, and are also returned. If the signature
    cannot be obtained (no ``generate``, builtin without a signature, ...), an
    empty set is stored instead.
    """
    import inspect

    try:
        accepted = {*inspect.signature(model.generate).parameters}
    except Exception:
        accepted = set()
    model._omnivoice_generate_params = accepted
    print(f"[model] generate params cached: {sorted(accepted)}", flush=True)
    return accepted
110
+
111
+
112
def build_generate_kwargs(
    params: set[str],
    text: str,
    ref_audio: str,
    ref_text: str | None = None,
    language: str | None = None,
    num_step: int | None = None,
    guidance_scale: float | None = None,
) -> dict[str, Any]:
    """Map TTS inputs onto whichever keyword names ``model.generate`` accepts.

    Args:
        params: parameter names of ``model.generate`` (see cache_generate_signature).
        text: text to synthesise. It is always sent, under ``text`` if no known alias matches.
        ref_audio: reference clip path. It is always sent, under ``ref_audio`` if no alias matches.
        ref_text: transcript of the reference; dropped if empty or unsupported.
        language: language hint; dropped if empty or unsupported.
        num_step: diffusion steps; defaults to OMNIVOICE_NUM_STEP.
        guidance_scale: guidance strength; defaults to OMNIVOICE_GUIDANCE_SCALE.

    Returns:
        The keyword arguments for ``model.generate``. Optional knobs (preprocess,
        speed, use_cache) are only included when the signature has them.
    """

    def pick(*aliases: str) -> str | None:
        # First alias that generate() accepts, or None when none of them is known.
        return next((alias for alias in aliases if alias in params), None)

    call_kwargs: dict[str, Any] = {
        pick("text", "prompt", "input_text") or "text": text,
        pick("ref_audio", "reference_audio", "reference_wav", "prompt_audio") or "ref_audio": ref_audio,
    }

    if ref_text and (text_key := pick("ref_text", "reference_text", "prompt_text")):
        call_kwargs[text_key] = ref_text

    if language and (lang_key := pick("language", "lang")):
        call_kwargs[lang_key] = language

    steps = OMNIVOICE_NUM_STEP if num_step is None else num_step
    if steps_key := pick("num_step", "num_steps"):
        call_kwargs[steps_key] = steps

    # The worker already trimmed the reference in prepare_ref_audio, so turn off
    # the model's own prompt preprocessing/trimming when that switch exists.
    if preprocess_key := pick("preprocess_prompt", "preprocess"):
        call_kwargs[preprocess_key] = False

    scale = OMNIVOICE_GUIDANCE_SCALE if guidance_scale is None else guidance_scale
    if "guidance_scale" in params:
        call_kwargs["guidance_scale"] = scale
    if "speed" in params:
        call_kwargs["speed"] = OMNIVOICE_SPEED
    if "use_cache" in params:
        call_kwargs["use_cache"] = True
    return call_kwargs
181
+
182
+
183
def _warmup_model(model: Any) -> None:
    """Run one throw-away generation against a 1s silent reference clip (best-effort).

    The silent clip is written to /tmp/warmup.wav on first use. Any error is
    printed and ignored so that a failed warmup never prevents startup.
    """
    print("🔥 Đang warmup model...", flush=True)
    try:
        silent_ref = "/tmp/warmup.wav"
        if not os.path.exists(silent_ref):
            import numpy as np

            sf.write(silent_ref, np.zeros(SAMPLE_RATE, dtype="float32"), SAMPLE_RATE, format="WAV")

        warm_kwargs = build_generate_kwargs(model._omnivoice_generate_params, "Warmup", silent_ref)
        with torch.inference_mode(), autocast_context():
            model.generate(**warm_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        print("✅ Warmup thành công.", flush=True)
    except Exception as exc:
        print(f"⚠️ Warmup lỗi (bỏ qua): {exc}", flush=True)


def load_model(device: str) -> Any:
    """Load OmniVoice onto *device*, cache its call signatures and warm it up.

    Weights are float16 when CUDA is available, otherwise float32. The
    parameter names of ``generate`` and ``create_voice_clone_prompt`` are stored
    on the model (``_omnivoice_generate_params`` / ``_omnivoice_prompt_params``;
    the latter is an empty set if introspection fails).
    """
    print("🔄 Đang tải model OmniVoice...", flush=True)
    t0 = time.time()

    from omnivoice import OmniVoice

    weights_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = OmniVoice.from_pretrained(MODEL_ID, device_map=device, dtype=weights_dtype)
    cache_generate_signature(model)

    import inspect

    try:
        model._omnivoice_prompt_params = set(
            inspect.signature(model.create_voice_clone_prompt).parameters.keys()
        )
        print(f"[model] voice prompt params cached: {sorted(model._omnivoice_prompt_params)}", flush=True)
    except Exception:
        # Older/other builds without create_voice_clone_prompt: disables prompt caching.
        model._omnivoice_prompt_params = set()
    print(f"✅ Model loaded trong {time.time() - t0:.1f}s", flush=True)

    _warmup_model(model)
    return model
220
+
221
+
222
+ def _audio_to_wav_bytes(audio: Any) -> bytes:
223
+ import numpy as np
224
+
225
+ if isinstance(audio, bytes):
226
+ return audio
227
+
228
+ if isinstance(audio, io.BytesIO):
229
+ audio.seek(0)
230
+ return audio.read()
231
+
232
+ if isinstance(audio, dict):
233
+ for key in ("audio", "wav", "output", "samples", "waveform"):
234
+ if key in audio:
235
+ return _audio_to_wav_bytes(audio[key])
236
+ if audio:
237
+ return _audio_to_wav_bytes(audio[next(iter(audio))])
238
+
239
+ if isinstance(audio, (list, tuple)):
240
+ if not audio:
241
+ raise ValueError("model.generate returned empty audio list")
242
+ return _audio_to_wav_bytes(audio[0])
243
+
244
+ if isinstance(audio, torch.Tensor):
245
+ audio = audio.detach().float().cpu().numpy()
246
+
247
+ if isinstance(audio, np.ndarray):
248
+ audio_np = audio.squeeze().astype("float32", copy=False)
249
+ buffer = io.BytesIO()
250
+ sf.write(buffer, audio_np, SAMPLE_RATE, format="WAV")
251
+ return buffer.getvalue()
252
+
253
+ raise TypeError(f"Unsupported model.generate output type: {type(audio)!r}")
254
+
255
+
256
def prepare_ref_audio(ref_audio: str) -> str:
    """Return the path of a cleaned copy of *ref_audio* for use as a voice-clone reference.

    Processing steps:
      1. down-mix multi-channel audio to mono;
      2. trim leading/trailing samples whose absolute amplitude is <= 0.01
         (-40 dB relative to 1.0), keeping 100 ms of padding on each side;
      3. cap the length at REF_AUDIO_MAX_SECONDS.

    Results are cached in REF_CACHE_DIR, keyed by source path, mtime and the
    length cap. The cached file is written to a temporary name and atomically
    renamed, so an interrupted write can never leave a truncated file that the
    cache check would later accept.

    Returns the original *ref_audio* unchanged when the file does not exist or
    when processing fails (best-effort).
    """
    # Defensive: REF_AUDIO_MAX_SECONDS is clamped to >= 1 at import time.
    if REF_AUDIO_MAX_SECONDS <= 0:
        return ref_audio

    src = Path(ref_audio)
    if not src.exists():
        return ref_audio

    REF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    trim_key = hashlib.md5(
        f"{src}:{src.stat().st_mtime_ns}:{REF_AUDIO_MAX_SECONDS}".encode(),
        usedforsecurity=False,
    ).hexdigest()
    trimmed = REF_CACHE_DIR / f"{trim_key}.clean.wav"
    # 44 bytes is the size of a bare WAV header, i.e. a file with no samples.
    if trimmed.exists() and trimmed.stat().st_size > 44:
        return str(trimmed)

    staging = trimmed.with_name(trimmed.name + ".tmp")
    try:
        import numpy as np

        audio, sr = sf.read(str(src), always_2d=False)

        # soundfile returns (frames, channels) for multi-channel files.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        threshold = 10 ** (-40.0 / 20.0)  # 0.01
        loud = np.flatnonzero(np.abs(audio) > threshold)
        if loud.size > 0:
            # Keep 100 ms around the audible region so word onsets/endings are not clipped.
            pad = int(sr * 0.1)
            start = max(0, int(loud[0]) - pad)
            end = min(len(audio), int(loud[-1]) + 1 + pad)
            audio = audio[start:end]

        audio = audio[: int(sr * REF_AUDIO_MAX_SECONDS)]

        sf.write(str(staging), audio, sr, format="WAV")
        staging.replace(trimmed)
        print(f"[ref] cleaned & trimmed {src.name} to {len(audio)/sr:.1f}s", flush=True)
        return str(trimmed)
    except Exception as exc:
        try:
            staging.unlink(missing_ok=True)
        except OSError:
            pass
        print(f"[ref] prep skipped: {exc}", flush=True)
        return ref_audio
307
+
308
+
309
def build_voice_prompt_kwargs(params: set[str], ref_audio: str, ref_text: str | None = None) -> dict[str, Any]:
    """Map the reference clip/transcript onto ``create_voice_clone_prompt``'s keyword names.

    The audio path is always included, under ``ref_audio`` when no known alias
    is in *params*. The transcript is included only when it is non-empty and
    one of its aliases is accepted.
    """
    audio_aliases = ("ref_audio", "reference_audio", "reference_wav", "prompt_audio")
    text_aliases = ("ref_text", "reference_text", "prompt_text")

    audio_key = next((name for name in audio_aliases if name in params), "ref_audio")
    prompt_kwargs: dict[str, Any] = {audio_key: ref_audio}

    if ref_text:
        text_key = next((name for name in text_aliases if name in params), None)
        if text_key is not None:
            prompt_kwargs[text_key] = ref_text
    return prompt_kwargs
330
+
331
+
332
def get_voice_clone_prompt(model: Any, ref_audio: str, ref_text: str | None = None) -> Any | None:
    """Return a pre-encoded voice-clone prompt for *ref_audio*, or None if unavailable.

    This only applies when the model exposes ``create_voice_clone_prompt`` and
    its parameter names were introspected at load time. Prompts are memoised in
    ``_voice_prompt_cache``, keyed by an MD5 of (path, mtime, ref_text). The
    cache keeps at most 15 entries and evicts the least-recently-used one,
    followed by gc and a CUDA cache flush. Build failures are logged and yield None.
    """
    if not hasattr(model, "create_voice_clone_prompt"):
        return None
    prompt_params = getattr(model, "_omnivoice_prompt_params", set())
    if not prompt_params:
        return None

    audio_path = Path(ref_audio)
    try:
        mtime_ns = audio_path.stat().st_mtime_ns
    except Exception:
        mtime_ns = 0
    cache_key = hashlib.md5(f"{ref_audio}:{mtime_ns}:{ref_text or ''}".encode(), usedforsecurity=False).hexdigest()

    if cache_key in _voice_prompt_cache:
        # Re-insert so the entry moves to the "most recently used" end.
        hit = _voice_prompt_cache.pop(cache_key)
        _voice_prompt_cache[cache_key] = hit
        return hit

    try:
        build_kwargs = build_voice_prompt_kwargs(prompt_params, ref_audio, ref_text)
        t_start = time.time()
        with torch.inference_mode(), autocast_context():
            prompt = model.create_voice_clone_prompt(**build_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Full cache: drop the least-recently-used entry (first in insertion order).
        if len(_voice_prompt_cache) >= 15:
            stale_key = next(iter(_voice_prompt_cache))
            if _voice_prompt_cache.pop(stale_key, None) is not None:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        _voice_prompt_cache[cache_key] = prompt
        print(f"[voice-cache] prompt built in {(time.time() - t_start) * 1000:.0f}ms", flush=True)
        return prompt
    except Exception as exc:
        print(f"[voice-cache] prompt build skipped: {exc}", flush=True)
        return None
377
+
378
+
379
def _finalize_generated_wav(audio_bytes: bytes, fade_seconds: float = 0.1, pad_seconds: float = 0.5) -> bytes:
    """Apply a cosine fade-out and append trailing silence to a WAV payload.

    The fade smooths the end of the clip. The trailing silence is meant to
    prevent truncation of the last word. Both mono and (frames, channels)
    audio are supported. On any decode/encode error the input is returned
    unchanged (best-effort).
    """
    import numpy as np

    try:
        samples, rate = sf.read(io.BytesIO(audio_bytes), always_2d=False)

        fade_len = int(rate * fade_seconds)
        if 0 < fade_len < len(samples):
            curve = 0.5 + 0.5 * np.cos(np.pi * np.linspace(0, 1, fade_len))
            # Broadcast the curve across channels for multi-channel audio.
            curve = curve.reshape((fade_len,) + (1,) * (samples.ndim - 1))
            samples[-fade_len:] = samples[-fade_len:] * curve

        silence = np.zeros((int(rate * pad_seconds),) + samples.shape[1:], dtype=samples.dtype)
        out = io.BytesIO()
        sf.write(out, np.concatenate([samples, silence]), rate, format="WAV")
        return out.getvalue()
    except Exception as exc:
        print(f"[post-process] failed to clean audio: {exc}", flush=True)
        return audio_bytes


def run_tts(model: Any, text: str, ref_audio: str, ref_text: str | None = None, language: str | None = None,
            num_step: int | None = None, guidance_scale: float | None = None) -> bytes:
    """Synthesise *text* in the voice of *ref_audio* and return WAV bytes (blocking).

    process_task runs this on the single-thread executor. Steps:
      1. clean and trim the reference clip (prepare_ref_audio);
      2. build ``generate`` kwargs for this model's signature;
      3. attach a cached pre-encoded voice prompt when ``generate`` accepts one;
      4. run generation under inference mode and autocast;
      5. fade out and pad the result.

    Peak CUDA memory stats are reset before generation so the caller can read
    this task's peak afterwards.
    """
    ref_audio = prepare_ref_audio(ref_audio)

    params = getattr(model, "_omnivoice_generate_params", set())
    kwargs = build_generate_kwargs(params, text, ref_audio, ref_text=ref_text, language=language,
                                   num_step=num_step, guidance_scale=guidance_scale)

    # Only build/look up the pre-encoded prompt when generate() can consume it.
    # Otherwise the encoding work and the GPU memory held by the cache are wasted.
    if "voice_clone_prompt" in params:
        voice_prompt = get_voice_clone_prompt(model, ref_audio, ref_text)
        if voice_prompt is not None:
            # Keep ref_audio/ref_text in kwargs as well: some OmniVoice API versions
            # still need them as context when a pre-encoded prompt is supplied.
            kwargs["voice_clone_prompt"] = voice_prompt
            logger.info("[tts] using cached voice prompt")

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    print(f"🎤 Generating TTS: {len(text)} chars...", flush=True)
    with torch.inference_mode(), autocast_context():
        output = model.generate(**kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return _finalize_generated_wav(_audio_to_wav_bytes(output))
433
+
434
+
435
def _clip_audio_bytes(payload: bytes, max_seconds: float) -> bytes:
    """Return *payload* cut to at most *max_seconds* of audio.

    The payload is returned unchanged when it is already short enough or cannot
    be decoded (best-effort; the failure is printed).
    """
    result = payload
    try:
        samples, rate = sf.read(io.BytesIO(payload), always_2d=False)
        limit = int(rate * max_seconds)
        if len(samples) > limit:
            clipped = io.BytesIO()
            sf.write(clipped, samples[:limit], rate, format="WAV")
            result = clipped.getvalue()
            print(f"[ref] trimmed download to {max_seconds:.1f}s", flush=True)
    except Exception as exc:
        print(f"[ref] download trim skipped: {exc}", flush=True)
    return result


async def download_ref_audio(client: httpx.AsyncClient, server_url: str, voice_url: str, max_seconds: float = REF_AUDIO_MAX_SECONDS) -> str:
    """Fetch a reference voice clip into the local cache and return its path.

    The cache file name is an MD5 of (voice_url, max_seconds). A cached file
    larger than a bare 44-byte WAV header is reused without downloading.
    Relative URLs (starting with ``/``) are resolved against *server_url*. The
    clip is clipped to *max_seconds* when positive, then written via a temp
    file and renamed into place.

    Raises:
        httpx.HTTPStatusError: if the server responds with an error status.
    """
    REF_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    digest = hashlib.md5(f"{voice_url}:{max_seconds}".encode(), usedforsecurity=False).hexdigest()
    cached = REF_CACHE_DIR / f"{digest}.wav"
    if cached.exists() and cached.stat().st_size > 44:
        return str(cached)

    source_url = voice_url if not voice_url.startswith("/") else f"{server_url}{voice_url}"
    response = await client.get(source_url, timeout=30)
    response.raise_for_status()

    body = response.content
    if max_seconds > 0:
        body = _clip_audio_bytes(body, max_seconds)

    staging = cached.with_suffix(".tmp")
    staging.write_bytes(body)
    staging.replace(cached)
    return str(cached)
466
+
467
+
468
async def send_json_safe(ws: Any, payload: dict[str, Any]) -> None:
    """JSON-encode *payload* and send it over *ws*, ignoring any failure.

    Deliberately best-effort: a closed socket or an unserialisable payload must
    never crash the caller (status/ack messages are advisory).
    """
    try:
        encoded = json.dumps(payload)
        await ws.send(encoded)
    except Exception:
        return
473
+
474
+
475
async def send_status(ws: Any, status: str, queue_size: int | None = None, worker_session_id: str = "") -> None:
    """Send a best-effort ``status`` message (e.g. IDLE/BUSY) for this worker session.

    ``queue_size`` is included only when provided.
    """
    extra: dict[str, Any] = {} if queue_size is None else {"queue_size": queue_size}
    await send_json_safe(
        ws,
        {"action": "status", "status": status, "worker_session_id": worker_session_id, **extra},
    )
480
+
481
+
482
async def process_task(
    model: Any,
    ws: Any,
    http_client: httpx.AsyncClient,
    server_url: str,
    data: dict[str, Any],
    worker_session_id: str = "",
) -> None:
    """Execute one ``run_tts`` task end-to-end and report the outcome.

    Steps:
      1. download (or reuse) the reference voice;
      2. run TTS on the single-thread executor;
      3. upload the WAV to ``/api/tasks/{task_id}/complete``;
      4. send ``task_completed`` over *ws*.

    Any failure is reported as ``task_failed`` over *ws* rather than raised,
    including a payload whose ``text`` or ``voice_api_url`` is missing. The
    caller's consumer loop therefore keeps running. GPU/host memory is released
    after every task.
    """
    task_id = data.get("task_id")

    try:
        missing = [field for field in ("text", "voice_api_url") if data.get(field) is None]
        if missing:
            raise ValueError(f"task payload missing required field(s): {', '.join(missing)}")
        text = data["text"]
        voice_url = data["voice_api_url"]
        ref_text = (data.get("voice_ref_text") or "").strip() or None
        language = data.get("language")

        short_text = text[:70] + ("..." if len(text) > 70 else "")
        print(f"[task] {task_id} | {short_text}", flush=True)

        ref_started = time.time()
        ref_path = await download_ref_audio(http_client, server_url, voice_url)
        ref_ms = (time.time() - ref_started) * 1000

        loop = asyncio.get_running_loop()
        tts_started = time.time()
        # run_tts blocks on GPU work. Running it on the executor keeps the event
        # loop free to answer server pings meanwhile.
        result_audio = await loop.run_in_executor(
            executor,
            run_tts,
            model,
            text,
            ref_path,
            ref_text,
            language,
            data.get("num_step"),
            data.get("guidance_scale"),
        )
        tts_ms = (time.time() - tts_started) * 1000

        upload_started = time.time()
        upload_url = f"{server_url}/api/tasks/{task_id}/complete"
        upload_response = await http_client.post(
            upload_url,
            data={"worker_session_id": worker_session_id},
            files={"audio": ("result.wav", result_audio, "audio/wav")},
            timeout=120,
        )
        upload_response.raise_for_status()
        upload_ms = (time.time() - upload_started) * 1000

        await send_json_safe(ws, {"action": "task_completed", "task_id": task_id, "worker_session_id": worker_session_id})

        # Rough duration for logging only: assumes 2 bytes per sample, mono, at SAMPLE_RATE.
        audio_seconds = len(result_audio) / (SAMPLE_RATE * 2)
        peak_mb = 0.0
        if torch.cuda.is_available():
            peak_mb = torch.cuda.max_memory_allocated() / 1024**2
        print(
            f"[ok] {task_id} ref={ref_ms:.0f}ms tts={tts_ms:.0f}ms upload={upload_ms:.0f}ms "
            f"audio~{audio_seconds:.1f}s peak={peak_mb:.0f}MB",
            flush=True,
        )
    except Exception as exc:
        print(f"[fail] {task_id}: {exc}", flush=True)
        await send_json_safe(ws, {"action": "task_failed", "task_id": task_id, "error": str(exc), "worker_session_id": worker_session_id})
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
534
+
535
+
536
async def task_consumer(
    queue: asyncio.Queue[dict[str, Any]],
    model: Any,
    ws: Any,
    http_client: httpx.AsyncClient,
    server_url: str,
    worker_session_id: str = "",
) -> None:
    """Process queued TTS tasks one at a time until cancelled.

    BUSY is reported before each task. Afterwards the status becomes IDLE when
    the queue is drained, otherwise BUSY. If an unexpected exception escapes
    process_task, it is logged and reported as ``task_failed`` instead of
    ending the loop, so one bad task cannot stall everything queued behind it.
    Cancellation (``asyncio.CancelledError``) still stops the loop.
    """
    while True:
        data = await queue.get()
        try:
            await send_status(ws, "BUSY", queue.qsize(), worker_session_id)
            await process_task(model, ws, http_client, server_url, data, worker_session_id)
        except Exception as exc:
            task_id = data.get("task_id") if isinstance(data, dict) else None
            print(f"[fail] {task_id}: {exc}", flush=True)
            await send_json_safe(ws, {
                "action": "task_failed",
                "task_id": task_id,
                "error": str(exc),
                "worker_session_id": worker_session_id,
            })
        finally:
            queue.task_done()
            await send_status(ws, "IDLE" if queue.empty() else "BUSY", queue.qsize(), worker_session_id)
552
+
553
+
554
def _ws_close_code(exc: BaseException) -> int | None:
    """Return the close code received from the server for a ConnectionClosed error, else None."""
    if not isinstance(exc, websockets.exceptions.ConnectionClosed):
        return None
    # Newer websockets releases expose the received Close frame as `rcvd`;
    # `code` is the legacy (deprecated) attribute, kept as a fallback.
    received = getattr(exc, "rcvd", None)
    if received is not None:
        return received.code
    return getattr(exc, "code", None)


async def _stop_consumer(consumer_task: asyncio.Task | None) -> None:
    """Cancel a per-connection consumer task, wait for it to finish, then free memory."""
    if consumer_task is not None:
        consumer_task.cancel()
        try:
            await consumer_task
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            print(f"[ws] consumer ended with error: {exc!r}", flush=True)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


async def worker_loop(model: Any, server_url: str, email: str, worker_session_id: str = "") -> None:
    """Keep a websocket session with the server and feed its TTS tasks to a consumer.

    For each connection the worker:
      * registers (email, session id, GPU name);
      * starts a task_consumer on a fresh bounded queue;
      * dispatches messages:
          - ``run_tts`` is enqueued, or rejected with ``task_failed`` when the queue is full;
          - ``ping`` is answered with ``pong_status``;
          - ``shutdown`` makes this coroutine return;
          - malformed or non-object messages are ignored.

    If no ping arrives for more than 30s, the connection is dropped and
    re-established immediately. Other connection errors trigger a reconnect
    every 5s. The process exits when:
      * reconnection keeps failing for more than 180s;
      * the server closes with code 4002;
      * this coroutine is cancelled.

    The per-connection consumer is always cancelled and awaited before the next
    connection attempt, so consumers never accumulate across reconnects.
    """
    ws_url = websocket_url(server_url)
    limits = httpx.Limits(max_connections=8, max_keepalive_connections=4)
    reconnect_start: float | None = None

    async with httpx.AsyncClient(limits=limits, http2=True, timeout=60) as http_client:
        while True:
            consumer_task: asyncio.Task | None = None
            task_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=TASK_QUEUE_MAXSIZE)
            last_ping_received = time.time()
            failure: Exception | None = None
            try:
                print(f"[ws] Connecting: {ws_url}", flush=True)
                async with websockets.connect(
                    ws_url,
                    open_timeout=30,
                    close_timeout=5,
                    ping_interval=None,
                ) as ws:
                    gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
                    await ws.send(json.dumps({
                        "action": "register",
                        "email": email,
                        "worker_session_id": worker_session_id,
                        "gpu": gpu,
                    }))
                    consumer_task = asyncio.create_task(
                        task_consumer(task_queue, model, ws, http_client, server_url, worker_session_id)
                    )
                    print(f"🚀 Connected: {email} (GPU: {gpu})", flush=True)
                    await send_status(ws, "IDLE", 0, worker_session_id)
                    reconnect_start = None

                    while True:
                        try:
                            raw = await asyncio.wait_for(ws.recv(), timeout=30.0)
                        except asyncio.TimeoutError:
                            if time.time() - last_ping_received > 30.0:
                                print("[ws] Heartbeat timeout - no ping for >30s", flush=True)
                                break
                            continue

                        # A single bad frame should not tear down the whole session.
                        try:
                            data = json.loads(raw)
                        except ValueError as exc:
                            print(f"[ws] Ignoring malformed message: {exc}", flush=True)
                            continue
                        if not isinstance(data, dict):
                            continue
                        action = data.get("action")

                        if action == "run_tts":
                            try:
                                task_queue.put_nowait(data)
                                await send_status(ws, "BUSY", task_queue.qsize(), worker_session_id)
                            except asyncio.QueueFull:
                                await send_json_safe(ws, {
                                    "action": "task_failed",
                                    "task_id": data.get("task_id"),
                                    "error": "Worker queue full",
                                    "worker_session_id": worker_session_id,
                                })
                        elif action == "ping":
                            last_ping_received = time.time()
                            current_status = "IDLE" if task_queue.empty() else "BUSY"
                            await ws.send(json.dumps({
                                "action": "pong_status",
                                "status": current_status,
                                "worker_session_id": worker_session_id,
                            }))
                        elif action == "shutdown":
                            print("[ws] Server yêu cầu shutdown.", flush=True)
                            return
            except asyncio.CancelledError:
                print("[ws] Worker cancelled. Exiting.", flush=True)
                sys.exit(1)
            except Exception as exc:
                # 4002 = session lost/expired on the server side: retrying is pointless.
                if _ws_close_code(exc) == 4002:
                    print("[ws] Connection rejected by server (code 4002: session lost/expired). Exiting.", flush=True)
                    sys.exit(1)
                failure = exc
            finally:
                # Runs on every exit path (heartbeat break, error, shutdown, exit),
                # so a stale consumer never keeps running against a dead socket.
                await _stop_consumer(consumer_task)

            if failure is None:
                # Heartbeat timeout: the connection was closed cleanly, reconnect right away.
                continue

            print(f"🔄 Reconnecting... ({failure})", flush=True)
            if reconnect_start is None:
                reconnect_start = time.time()
            elif time.time() - reconnect_start > 180:
                print("[ws] Reconnection timeout >180s. Exiting.", flush=True)
                sys.exit(1)
            await asyncio.sleep(5)
648
+
649
+
650
def parse_args() -> argparse.Namespace:
    """Parse worker CLI flags: required --server-url and --email, optional --worker-session-id."""
    parser = argparse.ArgumentParser(description="Colab OmniVoice TTS Worker")
    for required_flag in ("--server-url", "--email"):
        parser.add_argument(required_flag, required=True)
    parser.add_argument("--worker-session-id", default="")
    return parser.parse_args()
656
+
657
+
658
def _exit_with_error(message: object) -> None:
    """Print *message* with the ``[error]`` prefix and terminate with exit status 1."""
    print(f"[error] {message}", flush=True)
    sys.exit(1)


def main() -> None:
    """Command-line entry point: validate inputs, load OmniVoice, then run the worker loop."""
    args = parse_args()

    try:
        server_url = normalize_server_url(args.server_url)
    except ValueError as exc:
        _exit_with_error(exc)

    email = args.email.strip()
    if not email:
        _exit_with_error("EMAIL không được để trống.")

    worker_session_id = args.worker_session_id.strip()

    print(
        f"[fast-mode] num_step={OMNIVOICE_NUM_STEP} guidance_scale={OMNIVOICE_GUIDANCE_SCALE} "
        f"ref_max={REF_AUDIO_MAX_SECONDS}s speed={OMNIVOICE_SPEED}",
        flush=True,
    )
    device = detect_device()
    model = load_model(device)
    asyncio.run(worker_loop(model, server_url, email, worker_session_id))


if __name__ == "__main__":
    main()