mseep_txtai-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
--- /dev/null
+++ b/txtai/pipeline/audio/texttospeech.py
@@ -0,0 +1,553 @@
+"""
+TextToSpeech module
+"""
+
+# Conditional import
+try:
+    import onnxruntime as ort
+    import soundfile as sf
+
+    from ttstokenizer import IPATokenizer, TTSTokenizer
+
+    from .signal import Signal, SCIPY
+
+    TTS = SCIPY
+except ImportError:
+    TTS = False
+
+import json
+import logging
+
+from io import BytesIO
+
+import torch
+import yaml
+
+import numpy as np
+
+from huggingface_hub.errors import HFValidationError
+from transformers import SpeechT5Processor
+from transformers.utils import cached_file
+
+from ..base import Pipeline
+
+# Logging configuration
+logger = logging.getLogger(__name__)
+
+
+class TextToSpeech(Pipeline):
+    """
+    Generates speech from text.
+    """
+
+    def __init__(self, path=None, maxtokens=512, rate=22050):
+        """
+        Creates a new TextToSpeech pipeline.
+
+        Args:
+            path: optional model path
+            maxtokens: maximum number of tokens model can process, defaults to 512
+            rate: target sample rate, defaults to 22050
+        """
+
+        if not TTS:
+            raise ImportError('TextToSpeech pipeline is not available - install "pipeline" extra to enable')
+
+        # Default path
+        path = path if path else "neuml/ljspeech-jets-onnx"
+
+        # Target sample rate
+        self.rate = rate
+
+        # Load target tts pipeline
+        self.pipeline = None
+        if self.hasfile(path, "model.onnx") and self.hasfile(path, "config.yaml"):
+            self.pipeline = ESPnet(path, maxtokens, self.providers())
+        elif self.hasfile(path, "model.onnx") and self.hasfile(path, "voices.json"):
+            self.pipeline = Kokoro(path, maxtokens, self.providers())
+        else:
+            self.pipeline = SpeechT5(path, maxtokens, self.providers())
+
+    def __call__(self, text, stream=False, speaker=1, encoding=None, **kwargs):
+        """
+        Generates speech from text. Text longer than maxtokens will be batched and returned
+        as a single waveform per text input.
+
+        This method supports text as a string or a list. If the input is a string,
+        the return type is audio. If text is a list, the return type is a list.
+
+        Args:
+            text: text|list
+            stream: stream response if True, defaults to False
+            speaker: speaker id, defaults to 1
+            encoding: optional audio encoding format
+            kwargs: additional keyword args
+
+        Returns:
+            (audio, sample rate) or audio bytes depending on encoding parameter; a list of results when text is a list
+        """
+
+        # Convert input to a list if necessary
+        texts = [text] if isinstance(text, str) else text
+
+        # Streaming response
+        if stream:
+            return self.stream(texts, speaker, encoding)
+
+        # Transform text to speech
+        results = [self.execute(x, speaker, encoding, **kwargs) for x in texts]
+
+        # Return results
+        return results[0] if isinstance(text, str) else results
+
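A minimal usage sketch of the pipeline above, assuming the default neuml/ljspeech-jets-onnx model can be downloaded from the Hugging Face Hub:

    from txtai.pipeline import TextToSpeech

    # Downloads the default ESPnet model on first use
    tts = TextToSpeech()

    # Raw waveform: an (audio, sample rate) tuple
    audio, rate = tts("Text to speech with txtai")

    # Encoded bytes, written via soundfile with the requested format
    with open("speech.wav", "wb") as f:
        f.write(tts("Text to speech with txtai", encoding="wav"))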
+    def providers(self):
+        """
+        Returns a list of available and usable providers.
+
+        Returns:
+            list of available and usable providers
+        """
+
+        # Create list of providers, prefer CUDA provider if available
+        # CUDA provider only available if GPU is available and onnxruntime-gpu installed
+        if torch.cuda.is_available() and "CUDAExecutionProvider" in ort.get_available_providers():
+            return [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
+
+        # Default when CUDA provider isn't available
+        return ["CPUExecutionProvider"]
+
+    def hasfile(self, path, name):
+        """
+        Tests if a file exists in a local or remote repo.
+
+        Args:
+            path: model path
+            name: file name
+
+        Returns:
+            True if name exists in path, False otherwise
+        """
+
+        exists = False
+        try:
+            # Check if file exists
+            exists = cached_file(path_or_repo_id=path, filename=name) is not None
+        except (HFValidationError, OSError):
+            return False
+
+        return exists
+
+    def stream(self, texts, speaker, encoding):
+        """
+        Iterates over texts, splits into segments and yields snippets of audio.
+        This method is designed to integrate with streaming LLM generation.
+
+        Args:
+            texts: list of input texts
+            speaker: speaker id
+            encoding: audio encoding format
+
+        Returns:
+            snippets of audio as (audio, sample rate) tuples or audio bytes depending on encoding parameter
+        """
+
+        buffer = []
+        for x in texts:
+            buffer.append(x)
+
+            if x == "\n" or (x.strip().endswith(".") and len([y for y in buffer if y]) > 2):
+                data, buffer = "".join(buffer), []
+                yield self.execute(data, speaker, encoding)
+
+        if buffer:
+            data = "".join(buffer)
+            yield self.execute(data, speaker, encoding)
+
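The stream method is built to consume generators, which is how streaming LLM output arrives. A sketch, with a hypothetical token generator standing in for an LLM:

    # Hypothetical generator yielding text fragments as a streaming LLM would
    def tokens():
        yield from ["The", " sky", " is", " blue", ".", " Water", " is", " wet", "."]

    tts = TextToSpeech()

    # Buffered text flushes on a newline or a sentence-ending "." once more
    # than two non-empty fragments have accumulated; each snippet is an
    # (audio, sample rate) tuple since no encoding is set
    for audio, rate in tts(tokens(), stream=True):
        print(audio.shape, rate)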
+    def execute(self, text, speaker, encoding, **kwargs):
+        """
+        Runs the pipeline model for input text. The underlying pipeline builds batches
+        of tokens when len(tokens) > maxtokens.
+
+        Args:
+            text: text to tokenize and pass to model
+            speaker: speaker id
+            encoding: audio encoding format
+            kwargs: additional keyword args
+
+        Returns:
+            (audio, sample rate) or audio bytes depending on encoding parameter
+        """
+
+        # Run pipeline model
+        audio, rate = self.pipeline(text, speaker, **kwargs)
+
+        # Resample, if necessary
+        audio, rate = (Signal.resample(audio, rate, self.rate), self.rate) if self.rate else (audio, rate)
+
+        # Encode audio data, if requested
+        if encoding:
+            data = BytesIO()
+            sf.write(data, audio, rate, format=encoding)
+            return data.getvalue()
+
+        # Default to (audio, rate) tuple
+        return (audio, rate)
+
+
+class SpeechPipeline(Pipeline):
+    """
+    Base class for speech pipelines
+    """
+
+    # pylint: disable=W0221
+    def chunk(self, data, size, punctids):
+        """
+        Batching method that takes punctuation into account. This method splits data into
+        chunks of up to size tokens, preferring to split at the last punctuation token id
+        seen in the current batch.
+
+        Args:
+            data: data
+            size: batch size
+            punctids: list of punctuation token ids
+
+        Returns:
+            yields batches of data
+        """
+
+        # Iterate over each token
+        punct, index = 0, 0
+        for i, x in enumerate(data):
+            # Check if token is a punctuation token
+            if x in punctids:
+                punct = i
+
+            # Batch size reached, leave a spot for the punctuation token
+            if i - index >= (size - 1):
+                end = (punct if punct > index else i) + 1
+                yield data[index:end]
+                index = end
+
+        # Last batch
+        if index < len(data):
+            yield data[index : len(data)]
+
+
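A worked example of the chunking above; the token ids are made up, with 99 standing in for a punctuation id:

    pipeline = SpeechPipeline()

    # 8 tokens, punctuation id 99 at position 3, batch size 6: the first
    # batch cuts just after the last punctuation token seen, the remainder
    # is yielded as the final batch
    print(list(pipeline.chunk([1, 2, 3, 99, 4, 5, 6, 7], 6, [99])))
    # [[1, 2, 3, 99], [4, 5, 6, 7]]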
+class ESPnet(SpeechPipeline):
+    """
+    Text to Speech pipeline with an ESPnet ONNX model.
+    """
+
+    def __init__(self, path, maxtokens, providers):
+        """
+        Creates a new ESPnet pipeline.
+
+        Args:
+            path: model path
+            maxtokens: maximum number of tokens model can process
+            providers: list of supported ONNX providers
+        """
+
+        # Get path to model and config
+        config = cached_file(path_or_repo_id=path, filename="config.yaml")
+        model = cached_file(path_or_repo_id=path, filename="model.onnx")
+
+        # Read yaml config
+        with open(config, "r", encoding="utf-8") as f:
+            config = yaml.safe_load(f)
+
+        # Create tokenizer
+        tokens = config.get("token", {}).get("list")
+        self.tokenizer = TTSTokenizer(tokens)
+
+        # Create ONNX Session
+        self.model = ort.InferenceSession(model, ort.SessionOptions(), providers)
+
+        # Max number of input tokens model can handle
+        self.maxtokens = maxtokens
+
+        # Get model input name, typically "text"
+        self.input = self.model.get_inputs()[0].name
+
+        # Get parameter names
+        self.params = set(x.name for x in self.model.get_inputs())
+
+    def __call__(self, text, speaker):
+        """
+        Executes a model run. This method will build batches of tokens when len(tokens) > maxtokens.
+
+        Args:
+            text: text to tokenize and pass to model
+            speaker: speaker id
+
+        Returns:
+            (audio, sample rate)
+        """
+
+        # Debug logging for input text
+        logger.debug("%s", text)
+
+        # Sample rate
+        rate = 22050
+
+        # Tokenize input
+        tokens = self.tokenizer(text)
+
+        # Split into batches and process
+        results = []
+        for i, x in enumerate(self.chunk(tokens, self.maxtokens, self.tokenizer.punctuation())):
+            # Format input parameters
+            params = {self.input: x}
+            params = {**params, **{"sids": np.array([speaker])}} if "sids" in self.params else params
+
+            # Run text through TTS model and save waveform
+            output = self.model.run(None, params)
+            results.append(Signal.trim(output[0], rate, trailing=False) if i > 0 else output[0])
+
+        # Concatenate results and return
+        return (np.concatenate(results), rate)
+
+
+class Kokoro(SpeechPipeline):
+    """
+    Text to Speech pipeline with a Kokoro ONNX model.
+    """
+
+    def __init__(self, path, maxtokens, providers):
+        """
+        Creates a new Kokoro pipeline.
+
+        Args:
+            path: model path
+            maxtokens: maximum number of tokens model can process
+            providers: list of supported ONNX providers
+        """
+
+        # Get path to model and voices config
+        voices = cached_file(path_or_repo_id=path, filename="voices.json")
+        model = cached_file(path_or_repo_id=path, filename="model.onnx")
+
+        # Read voices config
+        with open(voices, "r", encoding="utf-8") as f:
+            self.voices = json.load(f)
+
+        # Create tokenizer
+        self.tokenizer = IPATokenizer()
+
+        # Create ONNX Session
+        self.model = ort.InferenceSession(model, ort.SessionOptions(), providers)
+
+        # Max number of input tokens model can handle
+        self.maxtokens = min(maxtokens, 510)
+
+        # Get model input name
+        self.input = self.model.get_inputs()[0].name
+
+        # Get parameter names
+        self.params = set(x.name for x in self.model.get_inputs())
+
+    def __call__(self, text, speaker=None, speed=1.0, transcribe=True):
+        """
+        Executes a model run. This method will build batches of tokens when len(tokens) > maxtokens.
+
+        Args:
+            text: text to tokenize and pass to model
+            speaker: speaker id, defaults to first speaker
+            speed: speech speed multiplier, defaults to 1.0
+            transcribe: whether text should be transcribed to IPA, defaults to True
+
+        Returns:
+            (audio, sample rate)
+        """
+
+        # Debug logging for input text
+        logger.debug("%s", text)
+
+        # Sample rate
+        rate = 24000
+
+        # Look up speaker, fall back to default
+        speaker = speaker if speaker in self.voices else next(iter(self.voices))
+        speaker = np.array(self.voices[speaker], dtype=np.float32)
+
+        # Tokenize input
+        self.tokenizer.transcribe = transcribe
+        tokens = self.tokenizer(text)
+
+        # Split into batches and process
+        results = []
+        for i, x in enumerate(self.chunk(tokens, self.maxtokens, self.tokenizer.punctuation())):
+            # Format input parameters
+            params = {self.input: [[0, *x, 0]], "style": speaker[len(x)], "speed": np.ones(1, dtype=np.float32) * speed}
+
+            # Run text through TTS model and save waveform
+            output = self.model.run(None, params)
+            results.append(Signal.trim(output[0], rate, trailing=False) if i > 0 else output[0])
+
+        # Concatenate results and return
+        return (np.concatenate(results), rate)
+
+
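Kokoro voices are keyed by name in voices.json, and extra keyword arguments flow through TextToSpeech to the pipeline, so voice and speed can be set per call. A sketch; the repository id and voice name are illustrative:

    # Any repository with model.onnx + voices.json routes to the Kokoro pipeline
    tts = TextToSpeech("neuml/kokoro-int8-onnx")

    # Unknown speaker names fall back to the first entry in voices.json
    audio, rate = tts("Hello world", speaker="af", speed=1.2)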
+class SpeechT5(SpeechPipeline):
+    """
+    Text to Speech pipeline with a SpeechT5 ONNX model.
+    """
+
+    def __init__(self, path, maxtokens, providers):
+        """
+        Creates a new SpeechT5 pipeline.
+
+        Args:
+            path: model path
+            maxtokens: maximum number of tokens model can process
+            providers: list of supported ONNX providers
+        """
+
+        self.encoder = ort.InferenceSession(cached_file(path_or_repo_id=path, filename="encoder_model.onnx"), providers=providers)
+        self.decoder = ort.InferenceSession(cached_file(path_or_repo_id=path, filename="decoder_model_merged.onnx"), providers=providers)
+        self.vocoder = ort.InferenceSession(cached_file(path_or_repo_id=path, filename="decoder_postnet_and_vocoder.onnx"), providers=providers)
+
+        self.processor = SpeechT5Processor.from_pretrained(path)
+        self.defaultspeaker = np.load(cached_file(path_or_repo_id=path, filename="speaker.npy"), allow_pickle=False)
+
+        # Max number of input tokens model can handle
+        self.maxtokens = maxtokens
+
+        # pylint: disable=E1101
+        # Punctuation token ids
+        self.punctids = [v for k, v in self.processor.tokenizer.get_vocab().items() if k in ".,!?;"]
+
+    def __call__(self, text, speaker):
+        """
+        Executes a model run. This method will build batches of tokens when len(tokens) > maxtokens.
+
+        Args:
+            text: text to tokenize and pass to model
+            speaker: speaker embeddings
+
+        Returns:
+            (audio, sample rate)
+        """
+
+        # Debug logging for input text
+        logger.debug("%s", text)
+
+        # Sample rate
+        rate = 16000
+
+        # Tokenize text
+        inputs = self.processor(text=text, return_tensors="np", normalize=True)
+
+        # Split into batches and process
+        results = []
+        for i, x in enumerate(self.chunk(inputs["input_ids"][0], self.maxtokens, self.punctids)):
+            # Run text through TTS model and save waveform
+            chunk = self.process(np.array([x], dtype=np.int64), speaker)
+            results.append(Signal.trim(chunk, rate, trailing=False) if i > 0 else chunk)
+
+        # Concatenate results and return
+        return (np.concatenate(results), rate)
+
+    def process(self, inputs, speaker):
+        """
+        Runs model inference.
+
+        Args:
+            inputs: input token ids
+            speaker: speaker embeddings
+
+        Returns:
+            waveform as NumPy array
+        """
+
+        # Run through encoder model
+        outputs = self.encoder.run(None, {"input_ids": inputs})
+        outputs = {key.name: outputs[x] for x, key in enumerate(self.encoder.get_outputs())}
+
+        # Encoder outputs and parameters
+        hiddenstate, attentionmask = outputs["encoder_outputs"], outputs["encoder_attention_mask"]
+        minlenratio, maxlenratio = 0.0, 20.0
+        reduction, threshold, melbins = 2, 0.5, 80
+
+        maxlen = int(hiddenstate.shape[1] * maxlenratio / reduction)
+        minlen = int(hiddenstate.shape[1] * minlenratio / reduction)
+
+        # Main processing loop
+        spectrogram, index, crossattention, branch, outputs = [], 0, None, False, {}
+        while True:
+            index += 1
+
+            inputs = {
+                "use_cache_branch": np.array([branch]),
+                "encoder_attention_mask": attentionmask,
+                "speaker_embeddings": speaker if speaker is not None and isinstance(speaker, np.ndarray) else self.defaultspeaker,
+            }
+
+            if index == 1:
+                inputs = self.placeholders(inputs)
+                inputs["output_sequence"] = np.zeros((1, 1, melbins)).astype(np.float32)
+                inputs["encoder_hidden_states"] = hiddenstate
+                branch = True
+            else:
+                inputs = self.inputs(inputs, outputs, crossattention)
+                inputs["output_sequence"] = outputs["output_sequence_out"]
+                inputs["encoder_hidden_states"] = np.zeros((1, 0, 768)).astype(np.float32)
+
+            # Run inputs through decoder
+            outputs = self.decoder.run(None, inputs)
+            outputs = {key.name: outputs[x] for x, key in enumerate(self.decoder.get_outputs())}
+
+            # Get cross attention with 1st pass
+            if index == 1:
+                crossattention = {key: val for key, val in outputs.items() if ("encoder" in key and "present" in key)}
+
+            # Decoder outputs
+            prob = outputs["prob"]
+            spectrum = outputs["spectrum"]
+            spectrogram.append(spectrum)
+
+            # Done when stop token or maximum length is reached.
+            if index >= minlen and (int(sum(prob >= threshold)) > 0 or index >= maxlen):
+                spectrogram = np.concatenate(spectrogram)
+                return self.vocoder.run(None, {"spectrogram": spectrogram})[0]
+
+    def placeholders(self, inputs):
+        """
+        Creates decoder model inputs for initial inference pass.
+
+        Args:
+            inputs: current decoder inputs
+
+        Returns:
+            updated decoder inputs
+        """
+
+        length = inputs["encoder_attention_mask"].shape[1]
+
+        for x in range(6):
+            inputs[f"past_key_values.{x}.encoder.key"] = np.zeros((1, 12, length, 64)).astype(np.float32)
+            inputs[f"past_key_values.{x}.encoder.value"] = np.zeros((1, 12, length, 64)).astype(np.float32)
+            inputs[f"past_key_values.{x}.decoder.key"] = np.zeros((1, 12, 1, 64)).astype(np.float32)
+            inputs[f"past_key_values.{x}.decoder.value"] = np.zeros((1, 12, 1, 64)).astype(np.float32)
+
+        return inputs
+
+    def inputs(self, inputs, previous, crossattention):
+        """
+        Creates decoder model inputs for follow-on inference passes.
+
+        Args:
+            inputs: current decoder inputs
+            previous: previous decoder outputs
+            crossattention: crossattention parameters
+
+        Returns:
+            updated decoder inputs
+        """
+
+        for x in range(6):
+            inputs[f"past_key_values.{x}.encoder.key"] = crossattention[f"present.{x}.encoder.key"]
+            inputs[f"past_key_values.{x}.encoder.value"] = crossattention[f"present.{x}.encoder.value"]
+            inputs[f"past_key_values.{x}.decoder.key"] = previous[f"present.{x}.decoder.key"]
+            inputs[f"past_key_values.{x}.decoder.value"] = previous[f"present.{x}.decoder.value"]
+
+        return inputs
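The SpeechT5 path takes speaker embeddings rather than an id: process falls back to the bundled speaker.npy unless a NumPy array is passed. A sketch; the repository id and the 512-dimension x-vector shape are assumptions based on standard SpeechT5 exports:

    import numpy as np

    # Hypothetical repository with the encoder/decoder/vocoder ONNX exports above
    tts = TextToSpeech("neuml/txtai-speecht5-onnx")

    # Non-array speaker values (like the default id 1) use speaker.npy;
    # an x-vector embedding array overrides it
    xvector = np.random.rand(1, 512).astype(np.float32)
    audio, rate = tts("Hello world", speaker=xvector)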