mseep_txtai-9.1.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,37 @@
+"""
+Registry module
+"""
+
+from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification
+from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
+
+
+class Registry:
+    """
+    Methods to register models and fully support pipelines.
+    """
+
+    @staticmethod
+    def register(model, config=None):
+        """
+        Registers a model with auto model and tokenizer configuration to fully support pipelines.
+
+        Args:
+            model: model to register
+            config: config class name
+        """
+
+        # Default config class to model class if not provided
+        config = config if config else model.__class__
+
+        # Default model config_class if empty
+        if hasattr(model.__class__, "config_class") and not model.__class__.config_class:
+            model.__class__.config_class = config
+
+        # Add references for this class to supported AutoModel classes
+        for mapping in [AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification]:
+            mapping.register(config, model.__class__)
+
+        # Add references for this class to support pipeline AutoTokenizers
+        if hasattr(model, "config") and type(model.config) not in TOKENIZER_MAPPING:
+            TOKENIZER_MAPPING.register(type(model.config), type(model.config).__name__)
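
For reference, a minimal sketch of how this Registry class can be used. The CustomConfig and CustomModel names below are illustrative placeholders, not part of this package; only Registry.register comes from txtai.

    import torch

    from transformers import PretrainedConfig, PreTrainedModel

    from txtai.models import Registry


    class CustomConfig(PretrainedConfig):
        # Unique model_type avoids clashing with built-in transformers mappings
        model_type = "custom-demo"


    class CustomModel(PreTrainedModel):
        # Left unset so Registry.register fills it in
        config_class = None

        def __init__(self, config):
            super().__init__(config)
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, input_ids=None, **kwargs):
            return self.linear(input_ids.float())


    # Register model so AutoModel (and related Auto* classes) can resolve it
    Registry.register(CustomModel(CustomConfig()), CustomConfig)
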
@@ -0,0 +1,122 @@
+"""
+Token Detection module
+"""
+
+import inspect
+import os
+
+import torch
+
+from transformers import PreTrainedModel
+
+
+class TokenDetection(PreTrainedModel):
+    """
+    Runs the replaced token detection training objective. This method was first proposed by the ELECTRA model.
+    The method consists of a masked language model generator feeding data to a discriminator that determines
+    which of the tokens are incorrect. More on this training objective can be found in the ELECTRA paper.
+    """
+
+    def __init__(self, generator, discriminator, tokenizer, weight=50.0):
+        """
+        Creates a new TokenDetection class.
+
+        Args:
+            generator: Generator model, must be a masked language model
+            discriminator: Discriminator model, must be a model that can detect replaced tokens. Any model
+                can be customized for this task. See ElectraForPreTraining for more.
+        """
+
+        # Initialize model with discriminator config
+        super().__init__(discriminator.config)
+
+        self.generator = generator
+        self.discriminator = discriminator
+
+        # Tokenizer to save with generator and discriminator
+        self.tokenizer = tokenizer
+
+        # Discriminator weight
+        self.weight = weight
+
+        # Share embeddings if both models are the same type
+        # Embeddings must be same size
+        if self.generator.config.model_type == self.discriminator.config.model_type:
+            self.discriminator.set_input_embeddings(self.generator.get_input_embeddings())
+
+        # Set attention mask present flags
+        self.gattention = "attention_mask" in inspect.signature(self.generator.forward).parameters
+        self.dattention = "attention_mask" in inspect.signature(self.discriminator.forward).parameters
+
+    # pylint: disable=E1101
+    def forward(self, input_ids=None, labels=None, attention_mask=None, token_type_ids=None):
+        """
+        Runs a forward pass through the model. This method runs the masked language model then randomly samples
+        the generated tokens and builds a binary classification problem for the discriminator (detecting if each token is correct).
+
+        Args:
+            input_ids: token ids
+            labels: token labels
+            attention_mask: attention mask
+            token_type_ids: segment token indices
+
+        Returns:
+            (loss, generator outputs, discriminator outputs, discriminator labels)
+        """
+
+        # Copy input ids
+        dinputs = input_ids.clone()
+
+        # Run inputs through masked language model
+        inputs = {"attention_mask": attention_mask} if self.gattention else {}
+        goutputs = self.generator(input_ids, labels=labels, token_type_ids=token_type_ids, **inputs)
+
+        # Get predictions
+        preds = torch.softmax(goutputs[1], dim=-1)
+        preds = preds.view(-1, self.config.vocab_size)
+
+        tokens = torch.multinomial(preds, 1).view(-1)
+        tokens = tokens.view(dinputs.shape[0], -1)
+
+        # Labels have a -100 value to ignore loss from unchanged tokens
+        mask = labels.ne(-100)
+
+        # Replace the masked out tokens of the input with the generator predictions
+        dinputs[mask] = tokens[mask]
+
+        # Turn mask into new target labels - 1 (True) for corrupted, 0 otherwise.
+        # If the prediction was correct, mark it as uncorrupted.
+        correct = tokens == labels
+        dlabels = mask.long()
+        dlabels[correct] = 0
+
+        # Run token classification, predict whether each token was corrupted
+        inputs = {"attention_mask": attention_mask} if self.dattention else {}
+        doutputs = self.discriminator(dinputs, labels=dlabels, token_type_ids=token_type_ids, **inputs)
+
+        # Compute combined loss
+        loss = goutputs[0] + self.weight * doutputs[0]
+        return loss, goutputs[1], doutputs[1], dlabels
+
+    def save_pretrained(self, output, state_dict=None, **kwargs):
+        """
+        Saves current model to output directory.

+
+        Args:
+            output: output directory
+            state_dict: model state
+            kwargs: additional keyword arguments
+        """
+
+        # Save combined model to support training from checkpoints
+        super().save_pretrained(output, state_dict, **kwargs)
+
+        # Save generator tokenizer and model
+        gpath = os.path.join(output, "generator")
+        self.tokenizer.save_pretrained(gpath)
+        self.generator.save_pretrained(gpath)
+
+        # Save discriminator tokenizer and model
+        dpath = os.path.join(output, "discriminator")
+        self.tokenizer.save_pretrained(dpath)
+        self.discriminator.save_pretrained(dpath)
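
And a short sketch of wiring TokenDetection into a training run. The ELECTRA checkpoints are illustrative; any masked language model generator paired with a replaced-token-detection discriminator should work, with labels built by a standard masked language modeling collator (-100 marks unchanged tokens).

    from transformers import AutoModelForMaskedLM, AutoTokenizer, ElectraForPreTraining

    from txtai.models import TokenDetection

    tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
    generator = AutoModelForMaskedLM.from_pretrained("google/electra-small-generator")
    discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    # Combined model runs both objectives; forward expects masked input_ids
    # plus labels from a masked language modeling data collator
    model = TokenDetection(generator, discriminator, tokenizer)
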
@@ -0,0 +1,17 @@
+"""
+Pipeline imports
+"""
+
+from .audio import *
+from .base import Pipeline
+from .data import *
+from .factory import PipelineFactory
+from .hfmodel import HFModel
+from .hfpipeline import HFPipeline
+from .image import *
+from .llm import *
+from .llm import RAG as Extractor
+from .nop import Nop
+from .text import *
+from .tensors import Tensors
+from .train import *
@@ -0,0 +1,11 @@
+"""
+Audio imports
+"""
+
+from .audiomixer import AudioMixer
+from .audiostream import AudioStream
+from .microphone import Microphone
+from .signal import Signal
+from .texttoaudio import TextToAudio
+from .texttospeech import TextToSpeech
+from .transcription import Transcription
@@ -0,0 +1,58 @@
+"""
+AudioMixer module
+"""
+
+from ..base import Pipeline
+from .signal import Signal, SCIPY
+
+
+class AudioMixer(Pipeline):
+    """
+    Mixes multiple audio streams into a single stream.
+    """
+
+    def __init__(self, rate=None):
+        """
+        Creates an AudioMixer pipeline.
+
+        Args:
+            rate: optional target sample rate, otherwise uses input target rate with each audio segment
+        """
+
+        if not SCIPY:
+            raise ImportError('AudioMixer pipeline is not available - install "pipeline" extra to enable.')
+
+        # Target sample rate
+        self.rate = rate
+
+    def __call__(self, segment, scale1=1, scale2=1):
+        """
+        Mixes multiple audio streams into a single stream.
+
+        Args:
+            segment: ((audio1, sample rate), (audio2, sample rate))|list
+            scale1: optional scaling factor for segment1
+            scale2: optional scaling factor for segment2
+
+        Returns:
+            list of (audio, sample rate)
+        """
+
+        # Convert single element to list
+        segments = [segment] if isinstance(segment, tuple) else segment
+
+        results = []
+        for segment1, segment2 in segments:
+            audio1, rate1 = segment1
+            audio2, rate2 = segment2
+
+            # Resample audio, as necessary
+            target = self.rate if self.rate else rate1
+            audio1 = Signal.resample(audio1, rate1, target)
+            audio2 = Signal.resample(audio2, rate2, target)
+
+            # Mix audio into single segment
+            results.append((Signal.mix(audio1, audio2, scale1, scale2), target))
+
+        # Return single element if single element passed in
+        return results[0] if isinstance(segment, tuple) else results
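
A quick usage sketch for AudioMixer, mixing two generated sine tones. It assumes the "pipeline" extra (scipy) is installed.

    import numpy as np

    from txtai.pipeline import AudioMixer

    rate = 16000
    t = np.linspace(0, 1, rate, dtype=np.float32)
    tone1 = np.sin(2 * np.pi * 440 * t)
    tone2 = np.sin(2 * np.pi * 660 * t)

    mixer = AudioMixer()

    # Single segment in, single (audio, sample rate) tuple out
    mixed, outrate = mixer(((tone1, rate), (tone2, rate)), scale1=0.5, scale2=0.5)
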
@@ -0,0 +1,94 @@
+"""
+AudioStream module
+"""
+
+from queue import Queue
+from threading import Thread
+
+# Conditional import
+try:
+    import sounddevice as sd
+
+    from .signal import Signal, SCIPY
+
+    AUDIOSTREAM = SCIPY
+except (ImportError, OSError):
+    AUDIOSTREAM = False
+
+from ..base import Pipeline
+
+
+class AudioStream(Pipeline):
+    """
+    Threaded pipeline that streams audio segments to an output audio device. This pipeline is designed
+    to run on local machines given that it requires access to write to an output device.
+    """
+
+    # End of stream message
+    COMPLETE = (1, None)
+
+    def __init__(self, rate=None):
+        """
+        Creates an AudioStream pipeline.
+
+        Args:
+            rate: optional target sample rate, otherwise uses input target rate with each audio segment
+        """
+
+        if not AUDIOSTREAM:
+            raise ImportError(
+                (
+                    'AudioStream pipeline is not available - install "pipeline" extra to enable. '
+                    "Also check that the portaudio system library is available."
+                )
+            )
+
+        # Target sample rate
+        self.rate = rate
+
+        self.queue = Queue()
+        self.thread = Thread(target=self.play)
+        self.thread.start()
+
+    def __call__(self, segment):
+        """
+        Queues audio segments for the audio player.
+
+        Args:
+            segment: (audio, sample rate)|list
+
+        Returns:
+            segment
+        """
+
+        # Convert single element to list
+        segments = [segment] if isinstance(segment, tuple) else segment
+
+        for x in segments:
+            self.queue.put(x)
+
+        # Return single element if single element passed in
+        return segments[0] if isinstance(segment, tuple) else segments
+
+    def wait(self):
+        """
+        Waits for all input audio segments to be played.
+        """
+
+        self.thread.join()
+
+    def play(self):
+        """
+        Reads audio segments from queue. This method runs in a separate non-blocking thread.
+        """
+
+        audio, rate = self.queue.get()
+        while not isinstance(audio, int) or (audio, rate) != AudioStream.COMPLETE:
+            # Resample to target sample rate, if necessary
+            audio, rate = (Signal.resample(audio, rate, self.rate), self.rate) if self.rate else (audio, rate)
+
+            # Play audio segment
+            sd.play(audio, rate, blocking=True)
+
+            # Get next segment
+            audio, rate = self.queue.get()
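
A usage sketch for AudioStream: queue audio, send the end of stream message, then wait for playback. It assumes the "pipeline" extra and the portaudio system library are installed.

    import numpy as np

    from txtai.pipeline import AudioStream

    rate = 16000
    t = np.linspace(0, 1, rate, dtype=np.float32)
    tone = np.sin(2 * np.pi * 440 * t)

    stream = AudioStream()
    stream((tone, rate))

    # Signal end of stream, then block until the player thread exits
    stream(AudioStream.COMPLETE)
    stream.wait()
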
@@ -0,0 +1,244 @@
+"""
+Microphone module
+"""
+
+import logging
+
+import numpy as np
+
+# Conditional import
+try:
+    import sounddevice as sd
+    import webrtcvad
+
+    from scipy.signal import butter, sosfilt
+
+    from .signal import Signal, SCIPY
+
+    MICROPHONE = SCIPY
+except (ImportError, OSError):
+    MICROPHONE = False
+
+from ..base import Pipeline
+
+# Logging configuration
+logger = logging.getLogger(__name__)
+
+
+class Microphone(Pipeline):
+    """
+    Reads input speech from a microphone device. This pipeline is designed to run on local machines given
+    that it requires access to read from an input device.
+    """
+
+    def __init__(self, rate=16000, vadmode=3, vadframe=20, vadthreshold=0.6, voicestart=300, voiceend=3400, active=5, pause=8):
+        """
+        Creates a new Microphone pipeline.
+
+        Args:
+            rate: sample rate to record audio in, defaults to 16000 (16 kHz)
+            vadmode: aggressiveness of the voice activity detector (1 - 3), defaults to 3, which is the most aggressive filter
+            vadframe: voice activity detector frame size in ms, defaults to 20
+            vadthreshold: percentage of frames (0.0 - 1.0) that must be voice to be considered speech, defaults to 0.6
+            voicestart: starting frequency to use for voice filtering, defaults to 300
+            voiceend: ending frequency to use for voice filtering, defaults to 3400
+            active: minimum number of active speech chunks to require before considering this speech, defaults to 5
+            pause: number of non-speech chunks to keep before considering speech complete, defaults to 8
+        """
+
+        if not MICROPHONE:
+            raise ImportError(
+                (
+                    'Microphone pipeline is not available - install "pipeline" extra to enable. '
+                    "Also check that the portaudio system library is available."
+                )
+            )
+
+        # Sample rate
+        self.rate = rate
+
+        # Voice activity detector
+        self.vad = webrtcvad.Vad(vadmode)
+        self.vadframe = vadframe
+        self.vadthreshold = vadthreshold
+
+        # Voice spectrum
+        self.voicestart = voicestart
+        self.voiceend = voiceend
+
+        # Audio chunks counts
+        self.active = active
+        self.pause = pause
+
+    def __call__(self, device=None):
+        """
+        Reads audio from an input device.
+
+        Args:
+            device: optional input device id, otherwise uses system default
+
+        Returns:
+            list of (audio, sample rate)
+        """
+
+        # Listen for audio
+        audio = self.listen(device[0] if isinstance(device, list) else device)
+
+        # Return single element if single element passed in
+        return (audio, self.rate) if device is None or not isinstance(device, list) else [(audio, self.rate)]
+
+    def listen(self, device):
+        """
+        Listens for speech. Detected speech is converted to 32-bit floats for compatibility with
+        automatic speech recognition (ASR) pipelines.
+
+        This method blocks until speech is detected.
+
+        Args:
+            device: input device
+
+        Returns:
+            audio
+        """
+
+        # Record in 100ms chunks
+        chunksize = self.rate // 10
+
+        # Open input stream
+        stream = sd.RawInputStream(device=device, samplerate=self.rate, channels=1, blocksize=chunksize, dtype=np.int16)
+
+        # Start the input stream
+        stream.start()
+
+        record, speech, nospeech, chunks = True, 0, 0, []
+        while record:
+            # Read chunk
+            chunk, _ = stream.read(chunksize)
+
+            # Detect speech using WebRTC VAD for audio chunk
+            detect = self.detect(chunk)
+            speech = speech + 1 if detect else speech
+            nospeech = 0 if detect else nospeech + 1
+
+            # Save chunk, if this is an active stream
+            if speech:
+                chunks.append(chunk)
+
+                # Pause limit has been reached, check if this audio should be accepted
+                if nospeech >= self.pause:
+                    logger.debug("Audio detected and being analyzed")
+                    if speech >= self.active and self.isspeech(chunks[:-nospeech]):
+                        # Disable recording
+                        record = False
+                    else:
+                        # Reset parameters and keep recording
+                        logger.debug("Speech not detected")
+                        speech, nospeech, chunks = 0, 0, []
+
+        # Stop the input stream
+        stream.stop()
+
+        # Convert to float32 and return
+        audio = np.frombuffer(b"".join(chunks), np.int16)
+        return Signal.float32(audio)
+
+    def isspeech(self, chunks):
+        """
+        Runs an ensemble of Voice Activity Detection (VAD) methods. Returns true if speech is
+        detected in the input audio chunks.
+
+        Args:
+            chunks: input audio chunks as byte buffers
+
+        Returns:
+            True if speech is detected, False otherwise
+        """
+
+        # Convert to NumPy array for processing
+        audio = np.frombuffer(b"".join(chunks), dtype=np.int16)
+
+        # Ensemble of:
+        # - WebRTC VAD with a human voice range Butterworth bandpass filter applied to the signal
+        # - FFT applied to detect the energy ratio for human voice range vs total range
+        return self.detectband(audio) and self.detectenergy(audio)
+
+    def detect(self, buffer):
+        """
+        Detect speech using the WebRTC Voice Activity Detector (VAD).
+
+        Args:
+            buffer: input audio buffer frame as bytes
+
+        Returns:
+            True if the ratio of frames with detected speech passes vadthreshold, False otherwise
+        """
+
+        n = int(self.rate * (self.vadframe / 1000.0) * 2)
+        offset = 0
+
+        detects = []
+        while offset + n <= len(buffer):
+            detects.append(1 if self.vad.is_speech(buffer[offset : offset + n], self.rate) else 0)
+            offset += n
+
+        # Calculate detection ratio and return
+        ratio = sum(detects) / len(detects) if detects else 0
+        if ratio > 0:
+            logger.debug("DETECT %.4f", ratio)
+
+        return ratio >= self.vadthreshold
+
+    def detectband(self, audio):
+        """
+        Detects speech using audio data filtered through a Butterworth bandpass filter
+        with the human voice range.
+
+        Args:
+            audio: input audio data as a NumPy array
+
+        Returns:
+            True if speech is detected, False otherwise
+        """
+
+        # Upsample to float32
+        audio = Signal.float32(audio)
+
+        # Human voice frequency range
+        low = self.voicestart / (0.5 * self.rate)
+        high = self.voiceend / (0.5 * self.rate)
+
+        # Low and high pass filter using human voice range
+        sos = butter(5, Wn=[low, high], btype="band", output="sos")
+        audio = sosfilt(sos, audio)
+
+        # Scale back to int16
+        audio = Signal.int16(audio)
+
+        # Pass filtered signal to WebRTC VAD
+        return self.detect(audio.tobytes())
+
+    def detectenergy(self, audio):
+        """
+        Detects speech by comparing the signal energy of the human voice range
+        to the overall signal energy.
+
+        Args:
+            audio: input audio data as a NumPy array
+
+        Returns:
+            True if speech is detected, False otherwise
+        """
+
+        # Calculate signal energy
+        energyfreq = Signal.energy(audio, self.rate)
+
+        # Sum speech energy
+        speechenergy = 0
+        for f, e in energyfreq.items():
+            if self.voicestart <= f <= self.voiceend:
+                speechenergy += e
+
+        # Calculate ratio of speech energy to total energy and return
+        ratio = speechenergy / sum(energyfreq.values())
+        logger.debug("SPEECH %.4f", ratio)
+        return ratio >= self.vadthreshold
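
A usage sketch for Microphone, pairing it with the Transcription pipeline from this same package. The Whisper model name is illustrative, and this assumes the "pipeline" extra plus the portaudio system library are installed.

    from txtai.pipeline import Microphone, Transcription

    microphone = Microphone()
    transcribe = Transcription("openai/whisper-base")

    # Blocks until speech is detected, then transcribes it
    print("Listening...")
    audio, rate = microphone()
    print(transcribe((audio, rate)))
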