mseep-txtai 9.1.1 (mseep_txtai-9.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/audio/texttospeech.py
@@ -0,0 +1,553 @@
"""
TextToSpeech module
"""

# Conditional import
try:
    import onnxruntime as ort
    import soundfile as sf

    from ttstokenizer import IPATokenizer, TTSTokenizer

    from .signal import Signal, SCIPY

    TTS = SCIPY
except ImportError:
    TTS = False

import json
import logging

from io import BytesIO

import torch
import yaml

import numpy as np

from huggingface_hub.errors import HFValidationError
from transformers import SpeechT5Processor
from transformers.utils import cached_file

from ..base import Pipeline

# Logging configuration
logger = logging.getLogger(__name__)


class TextToSpeech(Pipeline):
    """
    Generates speech from text.
    """

    def __init__(self, path=None, maxtokens=512, rate=22050):
        """
        Creates a new TextToSpeech pipeline.

        Args:
            path: optional model path
            maxtokens: maximum number of tokens model can process, defaults to 512
            rate: target sample rate, defaults to 22050
        """

        if not TTS:
            raise ImportError('TextToSpeech pipeline is not available - install "pipeline" extra to enable')

        # Default path
        path = path if path else "neuml/ljspeech-jets-onnx"

        # Target sample rate
        self.rate = rate

        # Load target tts pipeline
        self.pipeline = None
        if self.hasfile(path, "model.onnx") and self.hasfile(path, "config.yaml"):
            self.pipeline = ESPnet(path, maxtokens, self.providers())
        elif self.hasfile(path, "model.onnx") and self.hasfile(path, "voices.json"):
            self.pipeline = Kokoro(path, maxtokens, self.providers())
        else:
            self.pipeline = SpeechT5(path, maxtokens, self.providers())

    def __call__(self, text, stream=False, speaker=1, encoding=None, **kwargs):
        """
        Generates speech from text. Text longer than maxtokens will be batched and returned
        as a single waveform per text input.

        This method supports text as a string or a list. If the input is a string,
        the return type is audio. If text is a list, the return type is a list.

        Args:
            text: text|list
            stream: stream response if True, defaults to False
            speaker: speaker id, defaults to 1
            encoding: optional audio encoding format
            kwargs: additional keyword args

        Returns:
            list of (audio, sample rate) or list of audio depending on encoding parameter
        """

        # Convert results to a list if necessary
        texts = [text] if isinstance(text, str) else text

        # Streaming response
        if stream:
            return self.stream(texts, speaker, encoding)

        # Transform text to speech
        results = [self.execute(x, speaker, encoding, **kwargs) for x in texts]

        # Return results
        return results[0] if isinstance(text, str) else results

    def providers(self):
        """
        Returns a list of available and usable providers.

        Returns:
            list of available and usable providers
        """

        # Create list of providers, prefer CUDA provider if available
        # CUDA provider only available if GPU is available and onnxruntime-gpu installed
        if torch.cuda.is_available() and "CUDAExecutionProvider" in ort.get_available_providers():
            return [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]

        # Default when CUDA provider isn't available
        return ["CPUExecutionProvider"]

    def hasfile(self, path, name):
        """
        Tests if a file exists in a local or remote repo.

        Args:
            path: model path
            name: file name

        Returns:
            True if name exists in path, False otherwise
        """

        exists = False
        try:
            # Check if file exists
            exists = cached_file(path_or_repo_id=path, filename=name) is not None
        except (HFValidationError, OSError):
            return False

        return exists

    def stream(self, texts, speaker, encoding):
        """
        Iterates over texts, splits into segments and yields snippets of audio.
        This method is designed to integrate with streaming LLM generation.

        Args:
            texts: list of input texts
            speaker: speaker id
            encoding: audio encoding format

        Returns:
            snippets of audio as NumPy arrays or audio bytes depending on encoding parameter
        """

        buffer = []
        for x in texts:
            buffer.append(x)

            if x == "\n" or (x.strip().endswith(".") and len([y for y in buffer if y]) > 2):
                data, buffer = "".join(buffer), []
                yield self.execute(data, speaker, encoding)

        if buffer:
            data = "".join(buffer)
            yield self.execute(data, speaker, encoding)

    def execute(self, text, speaker, encoding, **kwargs):
        """
        Executes model run for an input array of tokens. This method will build batches
        of tokens when len(tokens) > maxtokens.

        Args:
            text: text to tokenize and pass to model
            speaker: speaker id
            encoding: audio encoding format
            kwargs: additional keyword args

        Returns:
            (audio, sample rate) or audio bytes depending on encoding parameter
        """

        # Run pipeline model
        audio, rate = self.pipeline(text, speaker, **kwargs)

        # Resample, if necessary and return
        audio, rate = (Signal.resample(audio, rate, self.rate), self.rate) if self.rate else (audio, rate)

        # Encoding audio data
        if encoding:
            data = BytesIO()
            sf.write(data, audio, rate, format=encoding)
            return data.getvalue()

        # Default to (audio, rate) tuple
        return (audio, rate)


class SpeechPipeline(Pipeline):
    """
    Base class for speech pipelines
    """

    # pylint: disable=W0221
    def chunk(self, data, size, punctids):
        """
        Batching method that takes punctuation into account. This method splits data up to size
        chunks. But it also searches the batch and splits on the last punctuation token id.

        Args:
            data: data
            size: batch size
            punctids: list of punctuation token ids

        Returns:
            yields batches of data
        """

        # Iterate over each token
        punct, index = 0, 0
        for i, x in enumerate(data):
            # Check if token is a punctuation token
            if x in punctids:
                punct = i

            # Batch size reached, leave a spot for the punctuation token
            if i - index >= (size - 1):
                end = (punct if punct > index else i) + 1
                yield data[index:end]
                index = end

        # Last batch
        if index < len(data):
            yield data[index : len(data)]


class ESPnet(SpeechPipeline):
    """
    Text to Speech pipeline with an ESPnet ONNX model.
    """

    def __init__(self, path, maxtokens, providers):
        """
        Creates a new ESPnet pipeline.

        Args:
            path: model path
            maxtokens: maximum number of tokens model can process
            providers: list of supported ONNX providers
        """

        # Get path to model and config
        config = cached_file(path_or_repo_id=path, filename="config.yaml")
        model = cached_file(path_or_repo_id=path, filename="model.onnx")

        # Read yaml config
        with open(config, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

        # Create tokenizer
        tokens = config.get("token", {}).get("list")
        self.tokenizer = TTSTokenizer(tokens)

        # Create ONNX Session
        self.model = ort.InferenceSession(model, ort.SessionOptions(), providers)

        # Max number of input tokens model can handle
        self.maxtokens = maxtokens

        # Get model input name, typically "text"
        self.input = self.model.get_inputs()[0].name

        # Get parameter names
        self.params = set(x.name for x in self.model.get_inputs())

    def __call__(self, text, speaker):
        """
        Executes a model run. This method will build batches of tokens when len(tokens) > maxtokens.

        Args:
            text: text to tokenize and pass to model
            speaker: speaker id

        Returns:
            (audio, sample rate)
        """

        # Debug logging for input text
        logger.debug("%s", text)

        # Sample rate
        rate = 22050

        # Tokenize input
        tokens = self.tokenizer(text)

        # Split into batches and process
        results = []
        for i, x in enumerate(self.chunk(tokens, self.maxtokens, self.tokenizer.punctuation())):
            # Format input parameters
            params = {self.input: x}
            params = {**params, **{"sids": np.array([speaker])}} if "sids" in self.params else params

            # Run text through TTS model and save waveform
            output = self.model.run(None, params)
            results.append(Signal.trim(output[0], rate, trailing=False) if i > 0 else output[0])

        # Concatenate results and return
        return (np.concatenate(results), rate)


class Kokoro(SpeechPipeline):
    """
    Text to Speech pipeline with a Kokoro ONNX model.
    """

    def __init__(self, path, maxtokens, providers):
        """
        Creates a new Kokoro pipeline.

        Args:
            path: model path
            maxtokens: maximum number of tokens model can process
            providers: list of supported ONNX providers
        """

        # Get path to model and config
        voices = cached_file(path_or_repo_id=path, filename="voices.json")
        model = cached_file(path_or_repo_id=path, filename="model.onnx")

        # Read voices config
        with open(voices, "r", encoding="utf-8") as f:
            self.voices = json.load(f)

        # Create tokenizer
        self.tokenizer = IPATokenizer()

        # Create ONNX Session
        self.model = ort.InferenceSession(model, ort.SessionOptions(), providers)

        # Max number of input tokens model can handle
        self.maxtokens = min(maxtokens, 510)

        # Get model input name
        self.input = self.model.get_inputs()[0].name

        # Get parameter names
        self.params = set(x.name for x in self.model.get_inputs())

    def __call__(self, text, speaker=None, speed=1.0, transcribe=True):
        """
        Executes a model run. This method will build batches of tokens when len(tokens) > maxtokens.

        Args:
            text: text to tokenize and pass to model
            speaker: speaker id, defaults to first speaker
            speed: defaults to 1.0
            transcribe: if text should be transcribed to IPA text, defaults to True

        Returns:
            (audio, sample rate)
        """

        # Debug logging for input text
        logger.debug("%s", text)

        # Sample rate
        rate = 24000

        # Looks up speaker, falls back to default
        speaker = speaker if speaker in self.voices else next(iter(self.voices))
        speaker = np.array(self.voices[speaker], dtype=np.float32)

        # Tokenize input
        self.tokenizer.transcribe = transcribe
        tokens = self.tokenizer(text)

        # Split into batches and process
        results = []
        for i, x in enumerate(self.chunk(tokens, self.maxtokens, self.tokenizer.punctuation())):
            # Format input parameters
            params = {self.input: [[0, *x, 0]], "style": speaker[len(x)], "speed": np.ones(1, dtype=np.float32) * speed}

            # Run text through TTS model and save waveform
            output = self.model.run(None, params)
            results.append(Signal.trim(output[0], rate, trailing=False) if i > 0 else output[0])

        # Concatenate results and return
        return (np.concatenate(results), rate)


class SpeechT5(SpeechPipeline):
    """
    Text to Speech pipeline with a SpeechT5 ONNX model.
    """

    def __init__(self, path, maxtokens, providers):
        """
        Creates a new SpeechT5 pipeline.

        Args:
            path: model path
            maxtokens: maximum number of tokens model can process
            providers: list of supported ONNX providers
        """

        self.encoder = ort.InferenceSession(cached_file(path_or_repo_id=path, filename="encoder_model.onnx"), providers=providers)
        self.decoder = ort.InferenceSession(cached_file(path_or_repo_id=path, filename="decoder_model_merged.onnx"), providers=providers)
        self.vocoder = ort.InferenceSession(cached_file(path_or_repo_id=path, filename="decoder_postnet_and_vocoder.onnx"), providers=providers)

        self.processor = SpeechT5Processor.from_pretrained(path)
        self.defaultspeaker = np.load(cached_file(path_or_repo_id=path, filename="speaker.npy"), allow_pickle=False)

        # Max number of input tokens model can handle
        self.maxtokens = maxtokens

        # pylint: disable=E1101
        # Punctuation token ids
        self.punctids = [v for k, v in self.processor.tokenizer.get_vocab().items() if k in ".,!?;"]

    def __call__(self, text, speaker):
        """
        Executes a model run. This method will build batches of tokens when len(tokens) > maxtokens.

        Args:
            text: text to tokenize and pass to model
            speaker: speaker embeddings

        Returns:
            (audio, sample rate)
        """

        # Debug logging for input text
        logger.debug("%s", text)

        # Sample rate
        rate = 16000

        # Tokenize text
        inputs = self.processor(text=text, return_tensors="np", normalize=True)

        # Split into batches and process
        results = []
        for i, x in enumerate(self.chunk(inputs["input_ids"][0], self.maxtokens, self.punctids)):
            # Run text through TTS model and save waveform
            chunk = self.process(np.array([x], dtype=np.int64), speaker)
            results.append(Signal.trim(chunk, rate, trailing=False) if i > 0 else chunk)

        # Concatenate results and return
        return (np.concatenate(results), rate)

    def process(self, inputs, speaker):
        """
        Runs model inference.

        Args:
            inputs: input token ids
            speaker: speaker embeddings

        Returns:
            waveform as NumPy array
        """

        # Run through encoder model
        outputs = self.encoder.run(None, {"input_ids": inputs})
        outputs = {key.name: outputs[x] for x, key in enumerate(self.encoder.get_outputs())}

        # Encoder outputs and parameters
        hiddenstate, attentionmask = outputs["encoder_outputs"], outputs["encoder_attention_mask"]
        minlenratio, maxlenratio = 0.0, 20.0
        reduction, threshold, melbins = 2, 0.5, 80

        maxlen = int(hiddenstate.shape[1] * maxlenratio / reduction)
        minlen = int(hiddenstate.shape[1] * minlenratio / reduction)

        # Main processing loop
        spectrogram, index, crossattention, branch, outputs = [], 0, None, False, {}
        while True:
            index += 1

            inputs = {
                "use_cache_branch": np.array([branch]),
                "encoder_attention_mask": attentionmask,
                "speaker_embeddings": speaker if speaker is not None and isinstance(speaker, np.ndarray) else self.defaultspeaker,
            }

            if index == 1:
                inputs = self.placeholders(inputs)
                inputs["output_sequence"] = np.zeros((1, 1, melbins)).astype(np.float32)
                inputs["encoder_hidden_states"] = hiddenstate
                branch = True
            else:
                inputs = self.inputs(inputs, outputs, crossattention)
                inputs["output_sequence"] = outputs["output_sequence_out"]
                inputs["encoder_hidden_states"] = np.zeros((1, 0, 768)).astype(np.float32)

            # Run inputs through decoder
            outputs = self.decoder.run(None, inputs)
            outputs = {key.name: outputs[x] for x, key in enumerate(self.decoder.get_outputs())}

            # Get cross attention with 1st pass
            if index == 1:
                crossattention = {key: val for key, val in outputs.items() if ("encoder" in key and "present" in key)}

            # Decoder outputs
            prob = outputs["prob"]
            spectrum = outputs["spectrum"]
            spectrogram.append(spectrum)

            # Done when stop token or maximum length is reached.
            if index >= minlen and (int(sum(prob >= threshold)) > 0 or index >= maxlen):
                spectrogram = np.concatenate(spectrogram)
                return self.vocoder.run(None, {"spectrogram": spectrogram})[0]

    def placeholders(self, inputs):
        """
        Creates decoder model inputs for initial inference pass.

        Args:
            inputs: current decoder inputs

        Returns:
            updated decoder inputs
        """

        length = inputs["encoder_attention_mask"].shape[1]

        for x in range(6):
            inputs[f"past_key_values.{x}.encoder.key"] = np.zeros((1, 12, length, 64)).astype(np.float32)
            inputs[f"past_key_values.{x}.encoder.value"] = np.zeros((1, 12, length, 64)).astype(np.float32)
            inputs[f"past_key_values.{x}.decoder.key"] = np.zeros((1, 12, 1, 64)).astype(np.float32)
            inputs[f"past_key_values.{x}.decoder.value"] = np.zeros((1, 12, 1, 64)).astype(np.float32)

        return inputs

    def inputs(self, inputs, previous, crossattention):
        """
        Creates decoder model inputs for follow-on inference passes.

        Args:
            inputs: current decoder inputs
            previous: previous decoder outputs
            crossattention: crossattention parameters

        Returns:
            updated decoder inputs
        """

        for x in range(6):
            inputs[f"past_key_values.{x}.encoder.key"] = crossattention[f"present.{x}.encoder.key"]
            inputs[f"past_key_values.{x}.encoder.value"] = crossattention[f"present.{x}.encoder.value"]
            inputs[f"past_key_values.{x}.decoder.key"] = previous[f"present.{x}.decoder.key"]
            inputs[f"past_key_values.{x}.decoder.value"] = previous[f"present.{x}.decoder.value"]

        return inputs
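
For context, a minimal usage sketch of the TextToSpeech pipeline shipped in the file above (not part of the package diff). It assumes the package is installed with the "pipeline" extra so the conditional onnxruntime/soundfile/ttstokenizer imports resolve, that TextToSpeech is re-exported from txtai.pipeline as in upstream txtai, and that the default "neuml/ljspeech-jets-onnx" model is reachable; the output filename is illustrative.

from txtai.pipeline import TextToSpeech

# Create pipeline; with no path this falls back to "neuml/ljspeech-jets-onnx"
tts = TextToSpeech()

# String input returns a single (audio, sample rate) tuple
audio, rate = tts("Speech generation with txtai")

# Passing an encoding returns encoded audio bytes instead (written via soundfile)
data = tts("Speech generation with txtai", encoding="wav")
with open("speech.wav", "wb") as output:  # illustrative output path
    output.write(data)

# stream=True yields (audio, sample rate) snippets as sentences complete,
# which is meant to pair with streaming LLM generation
for audio, rate in tts(["First sentence. ", "Second sentence. ", "Third sentence. "], stream=True):
    print(audio.shape, rate)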