openspeechapi 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openspeech/__init__.py +75 -0
- openspeech/__main__.py +5 -0
- openspeech/cli.py +413 -0
- openspeech/client/__init__.py +4 -0
- openspeech/client/client.py +145 -0
- openspeech/config.py +212 -0
- openspeech/core/__init__.py +0 -0
- openspeech/core/base.py +75 -0
- openspeech/core/enums.py +39 -0
- openspeech/core/models.py +61 -0
- openspeech/core/registry.py +37 -0
- openspeech/core/settings.py +8 -0
- openspeech/demo.py +675 -0
- openspeech/dispatch/__init__.py +0 -0
- openspeech/dispatch/context.py +34 -0
- openspeech/dispatch/dispatcher.py +661 -0
- openspeech/dispatch/executors/__init__.py +0 -0
- openspeech/dispatch/executors/base.py +34 -0
- openspeech/dispatch/executors/in_process.py +66 -0
- openspeech/dispatch/executors/remote.py +64 -0
- openspeech/dispatch/executors/subprocess_exec.py +446 -0
- openspeech/dispatch/fanout.py +95 -0
- openspeech/dispatch/filters.py +73 -0
- openspeech/dispatch/lifecycle.py +178 -0
- openspeech/dispatch/watcher.py +82 -0
- openspeech/engine_catalog.py +236 -0
- openspeech/engine_registry.yaml +347 -0
- openspeech/exceptions.py +51 -0
- openspeech/factory.py +325 -0
- openspeech/local_engines/__init__.py +12 -0
- openspeech/local_engines/aim_resolver.py +91 -0
- openspeech/local_engines/backends/__init__.py +1 -0
- openspeech/local_engines/backends/docker_backend.py +490 -0
- openspeech/local_engines/backends/native_backend.py +902 -0
- openspeech/local_engines/base.py +30 -0
- openspeech/local_engines/engines/__init__.py +1 -0
- openspeech/local_engines/engines/faster_whisper.py +36 -0
- openspeech/local_engines/engines/fish_speech.py +33 -0
- openspeech/local_engines/engines/sherpa_onnx.py +56 -0
- openspeech/local_engines/engines/whisper.py +41 -0
- openspeech/local_engines/engines/whisperlivekit.py +60 -0
- openspeech/local_engines/manager.py +208 -0
- openspeech/local_engines/models.py +50 -0
- openspeech/local_engines/progress.py +69 -0
- openspeech/local_engines/registry.py +19 -0
- openspeech/local_engines/task_store.py +52 -0
- openspeech/local_engines/tasks.py +71 -0
- openspeech/logging_config.py +607 -0
- openspeech/observe/__init__.py +0 -0
- openspeech/observe/base.py +79 -0
- openspeech/observe/debug.py +44 -0
- openspeech/observe/latency.py +19 -0
- openspeech/observe/metrics.py +47 -0
- openspeech/observe/tracing.py +44 -0
- openspeech/observe/usage.py +27 -0
- openspeech/providers/__init__.py +0 -0
- openspeech/providers/_template.py +101 -0
- openspeech/providers/stt/__init__.py +0 -0
- openspeech/providers/stt/alibaba.py +86 -0
- openspeech/providers/stt/assemblyai.py +135 -0
- openspeech/providers/stt/azure_speech.py +99 -0
- openspeech/providers/stt/baidu.py +135 -0
- openspeech/providers/stt/deepgram.py +311 -0
- openspeech/providers/stt/elevenlabs.py +385 -0
- openspeech/providers/stt/faster_whisper.py +211 -0
- openspeech/providers/stt/google_cloud.py +106 -0
- openspeech/providers/stt/iflytek.py +427 -0
- openspeech/providers/stt/macos_speech.py +226 -0
- openspeech/providers/stt/openai.py +84 -0
- openspeech/providers/stt/sherpa_onnx.py +353 -0
- openspeech/providers/stt/tencent.py +212 -0
- openspeech/providers/stt/volcengine.py +107 -0
- openspeech/providers/stt/whisper.py +153 -0
- openspeech/providers/stt/whisperlivekit.py +530 -0
- openspeech/providers/stt/windows_speech.py +249 -0
- openspeech/providers/tts/__init__.py +0 -0
- openspeech/providers/tts/alibaba.py +95 -0
- openspeech/providers/tts/azure_speech.py +123 -0
- openspeech/providers/tts/baidu.py +143 -0
- openspeech/providers/tts/coqui.py +64 -0
- openspeech/providers/tts/cosyvoice.py +90 -0
- openspeech/providers/tts/deepgram.py +174 -0
- openspeech/providers/tts/elevenlabs.py +311 -0
- openspeech/providers/tts/fish_speech.py +158 -0
- openspeech/providers/tts/google_cloud.py +107 -0
- openspeech/providers/tts/iflytek.py +209 -0
- openspeech/providers/tts/macos_say.py +251 -0
- openspeech/providers/tts/minimax.py +122 -0
- openspeech/providers/tts/openai.py +104 -0
- openspeech/providers/tts/piper.py +104 -0
- openspeech/providers/tts/tencent.py +189 -0
- openspeech/providers/tts/volcengine.py +117 -0
- openspeech/providers/tts/windows_sapi.py +234 -0
- openspeech/server/__init__.py +1 -0
- openspeech/server/app.py +72 -0
- openspeech/server/auth.py +42 -0
- openspeech/server/middleware.py +75 -0
- openspeech/server/routes/__init__.py +1 -0
- openspeech/server/routes/management.py +848 -0
- openspeech/server/routes/stt.py +121 -0
- openspeech/server/routes/tts.py +159 -0
- openspeech/server/routes/webui.py +29 -0
- openspeech/server/webui/app.js +2649 -0
- openspeech/server/webui/index.html +216 -0
- openspeech/server/webui/styles.css +617 -0
- openspeech/server/ws/__init__.py +1 -0
- openspeech/server/ws/stt_stream.py +263 -0
- openspeech/server/ws/tts_stream.py +207 -0
- openspeech/telemetry/__init__.py +21 -0
- openspeech/telemetry/perf.py +307 -0
- openspeech/utils/__init__.py +5 -0
- openspeech/utils/audio_converter.py +406 -0
- openspeech/utils/audio_playback.py +156 -0
- openspeech/vendor_registry.yaml +74 -0
- openspeechapi-0.1.0.dist-info/METADATA +101 -0
- openspeechapi-0.1.0.dist-info/RECORD +118 -0
- openspeechapi-0.1.0.dist-info/WHEEL +4 -0
- openspeechapi-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""FanOut/FanIn — concurrent dispatch to multiple providers with merge strategies."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import asyncio
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections.abc import Awaitable, Callable
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any
|
|
8
|
+
from openspeech.exceptions import FanOutAllFailedError
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class FanOutResult:
    """Per-provider outcome of a fan-out run, as produced by ``CollectAll``.

    Every provider name lands in exactly one of the two mappings.
    """

    # provider name -> value its awaitable returned
    successes: dict[str, Any] = field(default_factory=dict)
    # provider name -> Exception its awaitable raised
    errors: dict[str, Exception] = field(default_factory=dict)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MergeStrategy(ABC):
    """Policy that reduces per-provider outcomes to a single fan-out result.

    Implementations receive a mapping of provider name to either the value
    the provider returned or the Exception it raised.
    """

    @abstractmethod
    async def merge(self, results: dict[str, Any | Exception]) -> Any:
        """Combine *results* into the value the fan-out call returns."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class FirstCompleted(MergeStrategy):
    """Return the earliest successful result; fail only when all providers failed."""

    async def merge(self, results: dict[str, Any | Exception]) -> Any:
        """Return the first non-Exception value of *results* in insertion order.

        Raises:
            FanOutAllFailedError: if *results* holds no successful value; it
                receives the name -> exception mapping (empty for empty input).
        """
        for value in results.values():
            if not isinstance(value, Exception):
                return value
        failures = {
            provider: outcome
            for provider, outcome in results.items()
            if isinstance(outcome, Exception)
        }
        raise FanOutAllFailedError(failures)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class HighestConfidence(MergeStrategy):
    """Return the successful result carrying the largest ``confidence``."""

    async def merge(self, results: dict[str, Any | Exception]) -> Any:
        """Pick the successful value with the highest ``confidence`` attribute.

        A missing or falsy ``confidence`` (e.g. None) scores 0. On ties the
        earliest entry wins, because ``max`` keeps the first maximal item.

        Raises:
            FanOutAllFailedError: if no entry of *results* is a success.
        """
        def _score(candidate: Any) -> Any:
            return getattr(candidate, "confidence", 0) or 0

        winners = [value for value in results.values() if not isinstance(value, Exception)]
        if not winners:
            raise FanOutAllFailedError(
                {k: v for k, v in results.items() if isinstance(v, Exception)}
            )
        return max(winners, key=_score)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class CollectAll(MergeStrategy):
    """Keep every outcome, sorted into successes and errors; never raises."""

    async def merge(self, results: dict[str, Any | Exception]) -> FanOutResult:
        """Split *results* into a :class:`FanOutResult`, preserving insertion order."""
        outcome = FanOutResult()
        for provider, value in results.items():
            bucket = outcome.errors if isinstance(value, Exception) else outcome.successes
            bucket[provider] = value
        return outcome
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class CustomMerge(MergeStrategy):
    """Delegate merging to a caller-supplied async callable."""

    def __init__(self, fn: Callable[[dict[str, Any | Exception]], Awaitable[Any]]) -> None:
        # Stored as-is and awaited once per merge() call.
        self._fn = fn

    async def merge(self, results: dict[str, Any | Exception]) -> Any:
        """Await ``fn(results)`` and return whatever it produces."""
        merger = self._fn
        return await merger(results)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
async def fan_out(tasks: dict[str, Awaitable[Any]], strategy: MergeStrategy) -> Any:
    """Run every awaitable in *tasks* concurrently and merge the outcomes.

    ``FirstCompleted`` is routed to a dedicated path that can stop early.
    Any other strategy waits for all awaitables, then receives a mapping of
    provider name to returned value or raised Exception; entries appear in
    completion order.

    Args:
        tasks: Mapping of provider name to the awaitable to run.
        strategy: How to combine the collected outcomes.

    Returns:
        Whatever ``strategy.merge`` returns.
    """
    if isinstance(strategy, FirstCompleted):
        return await _fan_out_first_completed(tasks, strategy)

    outcomes: dict[str, Any | Exception] = {}

    async def _capture(key: str, awaitable: Awaitable[Any]) -> None:
        # Record either the value or the Exception under the provider's key.
        try:
            value = await awaitable
        except Exception as exc:
            outcomes[key] = exc
        else:
            outcomes[key] = value

    runners = (_capture(key, awaitable) for key, awaitable in tasks.items())
    await asyncio.gather(*runners)
    return await strategy.merge(outcomes)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
async def _fan_out_first_completed(tasks: dict[str, Awaitable[Any]], strategy: FirstCompleted) -> Any:
    """Run *tasks* concurrently and stop as soon as one of them succeeds.

    Each awaitable runs in its own task. The first successful completion
    wakes the caller; the remaining tasks are then cancelled, and *strategy*
    merges whatever outcomes were recorded. If every task fails, the merge
    sees all of the exceptions.

    The child tasks are always cancelled and awaited before this coroutine
    exits, including when the caller itself is cancelled. No provider call
    is left running in the background.

    Args:
        tasks: Mapping of provider name to the awaitable to run.
        strategy: Merge strategy applied to the collected outcomes.

    Returns:
        Whatever ``strategy.merge`` returns.
    """
    if not tasks:
        # Nothing would ever set the event below; merge the empty outcome
        # set directly instead of waiting forever.
        return await strategy.merge({})

    results: dict[str, Any | Exception] = {}
    done_event = asyncio.Event()
    unfinished = len(tasks)

    async def _run(name: str, coro: Awaitable[Any]) -> None:
        nonlocal unfinished
        try:
            results[name] = await coro
            done_event.set()
        except Exception as e:
            results[name] = e
        finally:
            # Counting in ``finally`` also covers awaitables that end with a
            # BaseException (e.g. their own CancelledError). Such outcomes are
            # never recorded in ``results``, and without this count the caller
            # could wait forever.
            unfinished -= 1
            if unfinished == 0:
                done_event.set()

    async_tasks = [asyncio.create_task(_run(name, coro)) for name, coro in tasks.items()]
    try:
        await done_event.wait()
    finally:
        # Runs on success and on caller cancellation alike.
        for t in async_tasks:
            if not t.done():
                t.cancel()
        await asyncio.gather(*async_tasks, return_exceptions=True)
    return await strategy.merge(results)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Result filter chain — post-processing filters applied to provider results."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Any, Generic, TypeVar
|
|
5
|
+
from openspeech.core.enums import AudioFormat
|
|
6
|
+
from openspeech.core.models import AudioData
|
|
7
|
+
|
|
8
|
+
T = TypeVar("T")


class ResultFilter(ABC, Generic[T]):
    """Base class for one post-processing step applied to a provider result.

    Both hooks default to pass-through, so a subclass overrides only what it
    needs: ``should_pass`` to reject results, ``transform`` to rewrite them.
    """

    def should_pass(self, result: T) -> bool:
        """Return False to drop *result*; the default accepts everything."""
        return True

    def transform(self, result: T) -> T:
        """Return the (possibly rewritten) *result*; the default is identity."""
        return result
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ConfidenceFilter(ResultFilter):
    """Drop results whose ``confidence`` is below a threshold.

    Results without a ``confidence`` attribute, or with None, pass.
    """

    def __init__(self, min_confidence: float = 0.8) -> None:
        # Inclusive bound: a result exactly at the threshold passes.
        self._min = min_confidence

    def should_pass(self, result: Any) -> bool:
        score = getattr(result, "confidence", None)
        return True if score is None else score >= self._min
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class LanguageFilter(ResultFilter):
    """Keep only results whose ``language`` is in an allow-list.

    Matching is plain set membership, with no case folding or locale
    normalisation. Results without a ``language``, or with None, pass.
    """

    def __init__(self, allow: list[str]) -> None:
        self._allow = set(allow)

    def should_pass(self, result: Any) -> bool:
        lang = getattr(result, "language", None)
        if lang is not None:
            return lang in self._allow
        return True
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class DurationFilter(ResultFilter):
    """Reject results whose ``duration_ms`` is shorter than ``min_ms``.

    Results without a ``duration_ms`` attribute, or with None, pass.
    """

    def __init__(self, min_ms: int = 100) -> None:
        self._min_ms = min_ms  # inclusive lower bound, in milliseconds

    def should_pass(self, result: Any) -> bool:
        length = getattr(result, "duration_ms", None)
        return length is None or length >= self._min_ms
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AudioFormatFilter(ResultFilter):
    """Stamp :class:`AudioData` results with a target ``format``.

    NOTE(review): this only relabels ``format``. ``data`` and ``sample_rate``
    are copied unchanged, so no conversion happens here. Confirm the bytes
    are already in the target format, or route through a real converter.
    """

    def __init__(self, target: AudioFormat = AudioFormat.PCM_16K) -> None:
        self._target = target

    def transform(self, result: Any) -> Any:
        # Non-audio results, and audio already tagged with the target format,
        # are returned as the same object.
        if not isinstance(result, AudioData) or result.format == self._target:
            return result
        return AudioData(
            data=result.data,
            sample_rate=result.sample_rate,
            channels=result.channels,
            format=self._target,
            duration_ms=result.duration_ms,
        )
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class FilterChain:
    """Apply a sequence of result filters in order.

    Each filter first decides whether the result survives (``should_pass``).
    If it does, the filter's ``transform`` output is handed to the next one.
    """

    def __init__(self, filters: list[ResultFilter]) -> None:
        """Snapshot *filters* (any iterable is accepted) into a private list.

        Without the copy, a one-shot iterator such as a generator would be
        exhausted by the first ``apply()``, silently disabling every filter
        for later results. Later mutation of the caller's list would also
        change this chain's behaviour.
        """
        self._filters = list(filters)

    def apply(self, result: Any) -> Any | None:
        """Run *result* through every filter.

        Returns:
            The transformed result, or None as soon as a filter rejects it
            (later filters are not consulted).
        """
        for f in self._filters:
            if not f.should_pass(result):
                return None
            result = f.transform(result)
        return result
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""Provider lifecycle management — lazy start, TTL-based auto-stop."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
import time
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from openspeech.logging_config import logger
|
|
10
|
+
|
|
11
|
+
from openspeech.logging_config import bind_context
|
|
12
|
+
from openspeech.telemetry.perf import Event, PerfTimer, milestone
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ProviderState(str, Enum):
    """Lifecycle state of a provider tracked by ProviderLifecycleManager.

    Normal progression: REGISTERED -> STARTING -> READY. STOPPED follows a
    stop (manual or idle TTL) or a failed start.
    """

    REGISTERED = "registered"  # known to the manager, never started
    STARTING = "starting"  # start-up in progress
    READY = "ready"  # started and usable
    STOPPED = "stopped"  # stopped, or start-up failed
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _ProviderEntry:
    """Mutable bookkeeping record for one provider registered with the manager."""

    def __init__(self, name: str, handle: Any, keepalive: int) -> None:
        self.name = name
        # Opaque handle; the manager reads provider_cls, settings_dict,
        # exec_mode and executor from it.
        self.handle = handle
        # Idle TTL in seconds; the idle checker ignores values <= 0.
        self.keepalive = keepalive
        self.state = ProviderState.REGISTERED
        # time.monotonic() of the last use; 0.0 means "never used".
        self.last_used: float = 0.0
        # Serialises start-up so concurrent ensure_ready() calls start once.
        self._lock = asyncio.Lock()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class ProviderLifecycleManager:
    """Manages per-provider state, lazy start, and TTL-based auto-stop.

    Providers are registered with a handle and an optional keepalive (idle
    TTL in seconds; 0 disables auto-stop). A provider is started on the
    first ``ensure_ready`` call. The background checker launched by
    ``start_idle_checker`` stops it once it has been idle longer than its
    keepalive.

    Concurrency: intended for coroutines on one event loop. Start and stop
    of a given provider are serialised on that provider's ``asyncio.Lock``.
    Nothing here is thread-safe.
    """

    def __init__(self) -> None:
        self._entries: dict[str, _ProviderEntry] = {}
        self._checker_task: asyncio.Task | None = None
        self._check_interval: float = 30.0  # seconds between idle sweeps
        # Passed to every executor.start() as ``http_client``; None until set.
        self._shared_http_client: Any = None

    def set_shared_http_client(self, client: Any) -> None:
        """Store a shared httpx.AsyncClient for injection into providers."""
        self._shared_http_client = client

    def register(self, name: str, handle: Any, keepalive: int = 0) -> None:
        """Register *name* in REGISTERED state without starting it.

        Args:
            name: Provider name used as the lookup key.
            handle: Object exposing ``provider_cls``, ``settings_dict``,
                ``exec_mode`` and ``executor`` (read by ``ensure_ready``).
            keepalive: Idle TTL in seconds; 0 disables idle auto-stop.

        NOTE(review): re-registering an existing name replaces its entry
        without stopping a provider that may already be running.
        """
        self._entries[name] = _ProviderEntry(name, handle, keepalive)

    def unregister(self, name: str) -> None:
        """Forget *name*; unknown names are ignored.

        NOTE(review): a running provider is dropped without its executor's
        ``stop()`` being called. ``await stop_provider(name)`` first if it
        may be running.
        """
        self._entries.pop(name, None)

    def get_state(self, name: str) -> ProviderState | None:
        """Return the state of *name*, or None if it is not registered."""
        entry = self._entries.get(name)
        return entry.state if entry else None

    def list_states(self) -> dict[str, str]:
        """Return a ``name -> state value`` snapshot of every registered provider."""
        return {name: entry.state.value for name, entry in self._entries.items()}

    async def ensure_ready(self, name: str) -> None:
        """Ensure provider *name* is READY, starting it if needed.

        Concurrent callers are serialised on the entry's lock, so the executor
        is started at most once per stop/start cycle. Settings are built from
        ``handle.settings_dict``, keeping only the keys that are fields of the
        provider's ``settings_cls`` (``BaseSettings`` if it declares none).

        Raises:
            ProviderNotFoundError: if *name* is not registered.
            Exception: whatever settings construction or ``executor.start``
                raised; the entry is left STOPPED.
        """
        entry = self._entries.get(name)
        if entry is None:
            from openspeech.exceptions import ProviderNotFoundError
            raise ProviderNotFoundError(name)

        if entry.state == ProviderState.READY:
            entry.last_used = time.monotonic()
            return

        async with entry._lock:
            # Double-check after acquiring lock: another caller may have
            # finished starting it while we waited.
            if entry.state == ProviderState.READY:
                entry.last_used = time.monotonic()
                return

            with bind_context(provider=name, engine=name):
                logger.info("lazy-starting provider '{}'", name)
                entry.state = ProviderState.STARTING
                try:
                    import dataclasses
                    from openspeech.core.settings import BaseSettings
                    handle = entry.handle
                    settings_cls = getattr(handle.provider_cls, "settings_cls", BaseSettings)
                    # Filter settings to only include fields the class accepts
                    valid_fields = {f.name for f in dataclasses.fields(settings_cls)}
                    filtered = {k: v for k, v in handle.settings_dict.items() if k in valid_fields}
                    settings = settings_cls(**filtered)
                    with PerfTimer(
                        Event.LIFECYCLE_PROVIDER_INIT,
                        exec_mode=handle.exec_mode.value,
                    ):
                        await handle.executor.start(
                            handle.provider_cls, settings,
                            http_client=self._shared_http_client,
                        )
                    entry.state = ProviderState.READY
                    entry.last_used = time.monotonic()
                    milestone(
                        Event.LIFECYCLE_PROVIDER_READY,
                        exec_mode=handle.exec_mode.value,
                    )
                    logger.info("provider '{}' ready", name)
                except Exception as e:
                    entry.state = ProviderState.STOPPED
                    milestone(
                        Event.PROVIDER_ERROR,
                        phase="init",
                        error_type=type(e).__name__,
                        error_message=str(e),
                    )
                    logger.error("failed to start provider '{}': {}", name, e)
                    raise
                finally:
                    # A BaseException (e.g. the caller being cancelled mid-start)
                    # bypasses the handler above; do not leave the entry
                    # reporting STARTING forever.
                    if entry.state == ProviderState.STARTING:
                        entry.state = ProviderState.STOPPED

    def get_instance(self, name: str) -> Any | None:
        """Return the running provider instance, or None.

        None is returned when *name* is unknown or not READY, or when its
        executor has no ``_provider`` attribute (presumably executors that do
        not hold the provider in-process; confirm per executor).
        """
        entry = self._entries.get(name)
        if entry is None or entry.state != ProviderState.READY:
            return None
        executor = entry.handle.executor
        return getattr(executor, "_provider", None)

    def touch(self, name: str) -> None:
        """Reset idle timer for a provider; unknown names are ignored."""
        entry = self._entries.get(name)
        if entry:
            entry.last_used = time.monotonic()

    async def stop_provider(self, name: str, *, reason: str = "manual") -> None:
        """Stop provider *name* if it is READY or STARTING; otherwise do nothing.

        The entry's start-up lock is taken, so a stop requested while a start
        is in flight waits for that start to finish and then stops it, rather
        than racing it. Errors from ``executor.stop()`` are logged, not raised.
        The entry always ends up STOPPED.
        """
        entry = self._entries.get(name)
        if entry is None or entry.state not in (ProviderState.READY, ProviderState.STARTING):
            return
        async with entry._lock:
            # Re-check: the in-flight start may have failed, or another stop
            # may have completed, while we waited for the lock.
            if entry.state != ProviderState.READY:
                return
            # Leave READY before awaiting stop(). Otherwise the lock-free fast
            # path in ensure_ready() could hand out a provider that is
            # shutting down; such callers now queue on the lock and restart it.
            entry.state = ProviderState.STOPPED
            with bind_context(provider=name, engine=name):
                try:
                    with PerfTimer(Event.LIFECYCLE_PROVIDER_STOP, reason=reason):
                        await entry.handle.executor.stop()
                    logger.info("provider '{}' stopped (reason={})", name, reason)
                except Exception as e:
                    logger.warning("error stopping provider '{}': {}", name, e)

    async def stop_all(self) -> None:
        """Stop all running providers and the idle checker."""
        if self._checker_task:
            self._checker_task.cancel()
            try:
                await self._checker_task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                # The checker may already have died with an error; that must
                # not prevent the providers below from being stopped.
                logger.warning("idle checker ended with error: {}", e)
            self._checker_task = None

        # Snapshot the names: stop_provider() awaits, and register() or
        # unregister() may resize the dict meanwhile. stop_provider() itself
        # skips entries that are not READY/STARTING, and waits out an
        # in-flight start before stopping it.
        for name in list(self._entries):
            await self.stop_provider(name)

    def start_idle_checker(self) -> None:
        """Start background task that stops idle providers (no-op if running)."""
        if self._checker_task is None or self._checker_task.done():
            self._checker_task = asyncio.create_task(self._idle_check_loop())

    async def _idle_check_loop(self) -> None:
        """Every ``_check_interval`` seconds, stop providers idle past their keepalive.

        Only READY entries with ``keepalive > 0`` that have been used at least
        once are candidates. A failure while handling one entry is logged and
        does not end the loop.
        """
        while True:
            await asyncio.sleep(self._check_interval)
            now = time.monotonic()
            for name, entry in list(self._entries.items()):
                if not (entry.state == ProviderState.READY
                        and entry.keepalive > 0
                        and entry.last_used > 0
                        and (now - entry.last_used) > entry.keepalive):
                    continue
                idle_s = now - entry.last_used
                try:
                    milestone(
                        Event.LIFECYCLE_IDLE_RECYCLE,
                        provider=name,
                        engine=name,
                        idle_seconds=round(idle_s, 2),
                        keepalive=entry.keepalive,
                    )
                    await self.stop_provider(name, reason="idle_ttl")
                except Exception as e:
                    # Unhandled, this would silently end the task and disable
                    # idle auto-stop for every provider.
                    logger.warning("idle check failed for provider '{}': {}", name, e)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Config file watcher for hot-reload."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Awaitable, Callable
|
|
7
|
+
|
|
8
|
+
from openspeech.logging_config import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ConfigWatcher:
    """Watches providers.yaml and .env for changes, triggers reload callback.

    Both files' modification times are polled every ``_poll_interval``
    seconds (2 by default). When a change is seen, the watcher:

    1. waits ``debounce_s`` seconds;
    2. re-loads ``.env`` into ``os.environ`` if it changed (only when
       python-dotenv is installed);
    3. awaits ``on_reload``.

    Errors from the callback are logged, not raised.
    """

    def __init__(
        self,
        config_path: Path,
        on_reload: Callable[[], Awaitable[dict[str, list[str]]]],
        debounce_s: float = 1.0,
    ) -> None:
        self._config_path = config_path
        self._env_path = config_path.parent / ".env"
        self._on_reload = on_reload
        self._debounce_s = debounce_s
        self._task: asyncio.Task | None = None
        self._last_config_mtime: float = 0.0
        self._last_env_mtime: float = 0.0
        self._poll_interval: float = 2.0  # poll every 2 seconds

    def start(self) -> None:
        """Start watching in background (no new task if one is running).

        The current mtimes become the baseline, so only edits made after
        this call trigger a reload.
        """
        self._last_config_mtime = self._get_mtime(self._config_path)
        self._last_env_mtime = self._get_mtime(self._env_path)
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._watch_loop())
            logger.info(f"Config watcher started: {self._config_path}")

    async def stop(self) -> None:
        """Cancel the background watcher and wait for it to exit."""
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            self._task = None

    @staticmethod
    def _get_mtime(path: Path) -> float:
        """Return *path*'s mtime, or 0.0 if it cannot be stat'ed.

        Any OSError maps to 0.0: a missing file, permission denied, or a
        path component that is not a directory. Catching only
        FileNotFoundError let the other cases escape and silently kill the
        watch task.
        """
        try:
            return path.stat().st_mtime
        except OSError:
            return 0.0

    def _poll_changes(self) -> tuple[bool, bool]:
        """Compare current mtimes with the last seen ones and record new values.

        A file counts as changed when its mtime differs from the last seen
        value in either direction. A file restored from backup or checked
        out by git can carry an *older* mtime. A file that is currently
        missing (mtime 0.0) is not reported, and its last seen mtime is kept,
        so deleting a file does not trigger a reload.

        Returns:
            ``(config_changed, env_changed)``.
        """
        config_mtime = self._get_mtime(self._config_path)
        config_changed = bool(config_mtime) and config_mtime != self._last_config_mtime
        if config_changed:
            self._last_config_mtime = config_mtime

        env_mtime = self._get_mtime(self._env_path)
        env_changed = bool(env_mtime) and env_mtime != self._last_env_mtime
        if env_changed:
            self._last_env_mtime = env_mtime

        return config_changed, env_changed

    def _reload_env(self) -> None:
        """Re-load .env into os.environ (override=True); no-op without python-dotenv."""
        try:
            from dotenv import load_dotenv
        except ImportError:
            return
        try:
            load_dotenv(self._env_path, override=True)
        except Exception as e:
            # Keep the watcher alive; the config reload below still runs.
            logger.error(f"Failed to reload {self._env_path}: {e}")

    async def _watch_loop(self) -> None:
        """Poll for file changes and trigger reload."""
        while True:
            await asyncio.sleep(self._poll_interval)
            config_changed, env_changed = self._poll_changes()
            if not (config_changed or env_changed):
                continue

            # Debounce: wait a bit to coalesce rapid saves. Then absorb any
            # saves that landed during the wait, so they are folded into this
            # reload instead of triggering a second one on the next poll.
            await asyncio.sleep(self._debounce_s)
            _, late_env_change = self._poll_changes()
            if env_changed or late_env_change:
                # Applied after the debounce so the final .env content is used.
                self._reload_env()

            try:
                result = await self._on_reload()
                logger.info(f"Config reloaded: {result}")
            except Exception as e:
                logger.error(f"Config reload failed: {e}")
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Unified engine catalog — loads engine list from registry YAML,
|
|
2
|
+
enriches with runtime metadata (default_settings, field_options) from provider code.
|
|
3
|
+
|
|
4
|
+
Only "installed" engines appear in providers.yaml and are visible on Dashboard/Config/Lab.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import dataclasses
|
|
9
|
+
import sys
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import yaml
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class CatalogEntry:
    """One engine in the catalog: static registry data plus provider metadata.

    The required fields come from engine_registry.yaml. ``default_settings``,
    ``field_options`` and ``pip_deps`` may later be overwritten with values
    taken from the provider class.
    """

    name: str  # unique engine ID, e.g. "openai-stt" or "fish-speech"
    provider: str  # key into the factory's _PROVIDER_MAP
    type: str  # "stt" or "tts"
    category: str  # "cloud", "local" or "native"
    description: str  # human-readable summary
    default_alias: str  # alias this engine uses in the config by default
    display_name: str = ""  # friendly label, e.g. "iFlytek STT"
    vendor: str = ""  # vendor_registry.yaml key (cloud engines only)
    default_settings: dict = field(default_factory=dict)
    default_exec_mode: str = "remote"
    pip_deps: list[str] = field(default_factory=list)
    pip_extras: list[str] = field(default_factory=list)  # extras names in pyproject.toml
    field_options: dict[str, list] = field(default_factory=dict)
    platforms: list[str] = field(default_factory=list)  # sys.platform values; empty = any

    @property
    def compatible(self) -> bool:
        """True when no platforms are listed or ``sys.platform`` is among them."""
        if self.platforms:
            return sys.platform in self.platforms
        return True
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
# Registry loading
|
|
43
|
+
# ---------------------------------------------------------------------------
|
|
44
|
+
|
|
45
|
+
_REGISTRY_PATH = Path(__file__).parent / "engine_registry.yaml"
_VENDOR_REGISTRY_PATH = Path(__file__).parent / "vendor_registry.yaml"


def _load_registry(path: Path | None = None) -> list[dict]:
    """Load engine list from registry YAML file.

    Args:
        path: Registry file to read; defaults to the packaged
            engine_registry.yaml.

    Returns:
        The list under the top-level ``engines`` key. ``[]`` is returned when
        the file is missing or empty, or when ``engines`` is absent or null.
    """
    path = path or _REGISTRY_PATH
    if not path.exists():
        return []
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}
    # A bare ``engines:`` key parses as None. ``.get(key, [])`` would pass that
    # None through and break callers that iterate the result.
    return raw.get("engines") or []
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _load_vendor_registry(path: Path | None = None) -> dict:
    """Load vendor registry from YAML file.

    Args:
        path: Registry file to read; defaults to the packaged
            vendor_registry.yaml.

    Returns:
        The mapping under the top-level ``vendors`` key. ``{}`` is returned
        when the file is missing or empty, or when ``vendors`` is absent or
        null.
    """
    path = path or _VENDOR_REGISTRY_PATH
    if not path.exists():
        return {}
    with open(path, encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}
    # A bare ``vendors:`` key parses as None; callers call ``.get`` on the
    # result, so normalise it to an empty mapping.
    return raw.get("vendors") or {}
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_vendor_registry() -> dict:
    """Return the vendor templates (credential field definitions) keyed by vendor.

    The packaged vendor_registry.yaml is re-read on every call; nothing is
    cached here.
    """
    vendors = _load_vendor_registry()
    return vendors
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _enrich_from_provider(entry: CatalogEntry) -> CatalogEntry:
|
|
75
|
+
"""Enrich a catalog entry with default_settings and field_options from provider code."""
|
|
76
|
+
try:
|
|
77
|
+
from openspeech.factory import _resolve
|
|
78
|
+
provider_cls, settings_cls = _resolve(entry.provider)
|
|
79
|
+
|
|
80
|
+
# Extract default_settings from settings dataclass defaults
|
|
81
|
+
if dataclasses.is_dataclass(settings_cls):
|
|
82
|
+
defaults = {}
|
|
83
|
+
for f in dataclasses.fields(settings_cls):
|
|
84
|
+
if f.default is not dataclasses.MISSING:
|
|
85
|
+
defaults[f.name] = f.default
|
|
86
|
+
elif f.default_factory is not dataclasses.MISSING:
|
|
87
|
+
defaults[f.name] = f.default_factory()
|
|
88
|
+
|
|
89
|
+
# Filter out vendor shared fields (e.g. api_key, api_secret)
|
|
90
|
+
# — these are inherited from vendor credentials, not engine settings
|
|
91
|
+
if entry.vendor:
|
|
92
|
+
vendor_registry = _load_vendor_registry()
|
|
93
|
+
vendor_tpl = vendor_registry.get(entry.vendor, {})
|
|
94
|
+
shared_keys = set(vendor_tpl.get("shared_fields", {}).keys())
|
|
95
|
+
defaults = {k: v for k, v in defaults.items() if k not in shared_keys}
|
|
96
|
+
|
|
97
|
+
entry.default_settings = defaults
|
|
98
|
+
|
|
99
|
+
# Extract field_options from provider class attribute
|
|
100
|
+
fo = getattr(provider_cls, "field_options", None)
|
|
101
|
+
if fo:
|
|
102
|
+
entry.field_options = dict(fo)
|
|
103
|
+
|
|
104
|
+
# Extract pip_deps from provider class if available
|
|
105
|
+
pd = getattr(provider_cls, "pip_deps", None)
|
|
106
|
+
if pd:
|
|
107
|
+
entry.pip_deps = list(pd)
|
|
108
|
+
|
|
109
|
+
except Exception:
|
|
110
|
+
pass # Provider not importable (missing deps) — keep registry defaults
|
|
111
|
+
|
|
112
|
+
return entry
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def build_catalog(registry_path: Path | None = None) -> list[CatalogEntry]:
    """Build the full engine catalog from registry YAML + provider metadata.

    Every registry item must provide ``name``, ``provider`` and ``type``; a
    missing one raises KeyError. Other keys fall back to defaults, and
    ``default_alias`` falls back to ``name`` with "-" replaced by "_". Each
    entry is then passed through ``_enrich_from_provider``.

    Args:
        registry_path: Registry file; None uses the packaged one.

    Returns:
        A new list of entries, in registry order.
    """
    catalog: list[CatalogEntry] = []
    for item in _load_registry(registry_path):
        entry = CatalogEntry(
            name=item["name"],
            provider=item["provider"],
            type=item["type"],
            category=item.get("category", "cloud"),
            description=item.get("description", ""),
            default_alias=item.get("default_alias", item["name"].replace("-", "_")),
            display_name=item.get("display_name", ""),
            vendor=item.get("vendor", ""),
            default_exec_mode=item.get("default_exec_mode", "remote"),
            pip_extras=item.get("pip_extras", []),
            platforms=item.get("platforms", []),
        )
        catalog.append(_enrich_from_provider(entry))
    return catalog
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# ---------------------------------------------------------------------------
|
|
142
|
+
# Singleton & helpers
|
|
143
|
+
# ---------------------------------------------------------------------------
|
|
144
|
+
|
|
145
|
+
# Process-wide catalog cache; None until get_catalog() first builds it.
_catalog: list[CatalogEntry] | None = None


def get_catalog() -> list[CatalogEntry]:
    """Return the process-wide catalog, building it on first use.

    Once built, the same list object is returned; this function never
    rebuilds it.
    """
    global _catalog
    if _catalog is not None:
        return _catalog
    _catalog = build_catalog()
    return _catalog
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def get_catalog_entry(name: str) -> CatalogEntry | None:
    """Look up a catalog entry by name or default_alias.

    For native meta-aliases (native_stt / native_tts), returns the
    platform-specific entry (e.g. windows-stt on win32), or None when the
    current platform has no mapping or its provider is not in the catalog.
    """
    from openspeech.factory import _NATIVE_ALIASES

    # A meta-alias matches by its name or its underscore form.
    for meta_name, platform_map in _NATIVE_ALIASES.items():
        if name not in (meta_name, meta_name.replace("-", "_")):
            continue
        concrete = platform_map.get(sys.platform)
        if not concrete:
            return None
        return next((e for e in get_catalog() if e.provider == concrete), None)

    return next(
        (e for e in get_catalog() if e.name == name or e.default_alias == name),
        None,
    )
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def get_installed_engines(config_path: Path) -> set[str]:
    """Return set of catalog engine names that are installed (present in config).

    Native meta-aliases (native_stt/native_tts) are resolved to platform-specific
    catalog entries (e.g. windows-stt on win32).

    Two config layouts are accepted:

    * new format: engine specs under ``engines:``, credential providers under
      ``providers:``;
    * old format: engine specs directly under ``providers:``.

    A spec is a mapping with an ``exec_mode`` or ``provider`` key; anything
    else (e.g. credential entries) is skipped. Specs are matched in this
    order:

    1. by alias, via ``get_catalog_entry``;
    2. in the new format, when ``provider`` names a ``providers:`` entry, by
       vendor + default alias;
    3. otherwise, by factory provider key.

    Args:
        config_path: Path to providers.yaml.

    Returns:
        Catalog engine names; empty if the file does not exist.
    """
    if not config_path.exists():
        return set()
    with open(config_path, encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}

    # Support both new format (engines:) and old format (providers:). The
    # format is decided by the same test that picks the section, so an empty
    # or null ``engines:`` key is consistently treated as the old format.
    engines_raw = raw.get("engines") or {}
    is_new_format = bool(engines_raw)
    if not is_new_format:
        engines_raw = raw.get("providers") or {}

    # Load providers section for credential provider resolution
    providers_section = raw.get("providers") or {}

    catalog = get_catalog()
    # Reverse map: factory provider key -> catalog name
    provider_to_catalog = {e.provider: e.name for e in catalog}

    installed: set[str] = set()
    for alias, spec in engines_raw.items():
        if not isinstance(spec, dict):
            continue
        # Skip credential provider entries (neither exec_mode nor provider)
        if "exec_mode" not in spec and "provider" not in spec:
            continue

        # Resolve via get_catalog_entry (handles native meta-aliases)
        resolved = get_catalog_entry(alias)
        if resolved:
            installed.add(resolved.name)
            continue

        provider_val = spec.get("provider", "")

        if is_new_format and provider_val in providers_section:
            # New format: provider references a credential provider (vendor)
            for e in catalog:
                if e.vendor == provider_val and e.default_alias == alias:
                    installed.add(e.name)
                    break
        elif provider_val in provider_to_catalog:
            # Direct factory key
            installed.add(provider_to_catalog[provider_val])

    return installed
|