mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/models/registry.py
ADDED
@@ -0,0 +1,37 @@
+"""
+Registry module
+"""
+
+from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification
+from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
+
+
+class Registry:
+    """
+    Methods to register models and fully support pipelines.
+    """
+
+    @staticmethod
+    def register(model, config=None):
+        """
+        Registers a model with auto model and tokenizer configuration to fully support pipelines.
+
+        Args:
+            model: model to register
+            config: config class name
+        """
+
+        # Default config class to model class if not provided
+        config = config if config else model.__class__
+
+        # Default model config_class if empty
+        if hasattr(model.__class__, "config_class") and not model.__class__.config_class:
+            model.__class__.config_class = config
+
+        # Add references for this class to supported AutoModel classes
+        for mapping in [AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification]:
+            mapping.register(config, model.__class__)
+
+        # Add references for this class to support pipeline AutoTokenizers
+        if hasattr(model, "config") and type(model.config) not in TOKENIZER_MAPPING:
+            TOKENIZER_MAPPING.register(type(model.config), type(model.config).__name__)
txtai/models/tokendetection.py
ADDED
@@ -0,0 +1,122 @@
+"""
+Token Detection module
+"""
+
+import inspect
+import os
+
+import torch
+
+from transformers import PreTrainedModel
+
+
+class TokenDetection(PreTrainedModel):
+    """
+    Runs the replaced token detection training objective. This method was first proposed by the ELECTRA model.
+    The method consists of a masked language model generator feeding data to a discriminator that determines
+    which of the tokens are incorrect. More on this training objective can be found in the ELECTRA paper.
+    """
+
+    def __init__(self, generator, discriminator, tokenizer, weight=50.0):
+        """
+        Creates a new TokenDetection class.
+
+        Args:
+            generator: Generator model, must be a masked language model
+            discriminator: Discriminator model, must be a model that can detect replaced tokens. Any model
+                           can be customized for this task. See ElectraForPreTraining for more.
+        """
+
+        # Initialize model with discriminator config
+        super().__init__(discriminator.config)
+
+        self.generator = generator
+        self.discriminator = discriminator
+
+        # Tokenizer to save with generator and discriminator
+        self.tokenizer = tokenizer
+
+        # Discriminator weight
+        self.weight = weight
+
+        # Share embeddings if both models are the same type
+        # Embeddings must be same size
+        if self.generator.config.model_type == self.discriminator.config.model_type:
+            self.discriminator.set_input_embeddings(self.generator.get_input_embeddings())
+
+        # Set attention mask present flags
+        self.gattention = "attention_mask" in inspect.signature(self.generator.forward).parameters
+        self.dattention = "attention_mask" in inspect.signature(self.discriminator.forward).parameters
+
+    # pylint: disable=E1101
+    def forward(self, input_ids=None, labels=None, attention_mask=None, token_type_ids=None):
+        """
+        Runs a forward pass through the model. This method runs the masked language model then randomly samples
+        the generated tokens and builds a binary classification problem for the discriminator (detecting if each token is correct).
+
+        Args:
+            input_ids: token ids
+            labels: token labels
+            attention_mask: attention mask
+            token_type_ids: segment token indices
+
+        Returns:
+            (loss, generator outputs, discriminator outputs, discriminator labels)
+        """
+
+        # Copy input ids
+        dinputs = input_ids.clone()
+
+        # Run inputs through masked language model
+        inputs = {"attention_mask": attention_mask} if self.gattention else {}
+        goutputs = self.generator(input_ids, labels=labels, token_type_ids=token_type_ids, **inputs)
+
+        # Get predictions
+        preds = torch.softmax(goutputs[1], dim=-1)
+        preds = preds.view(-1, self.config.vocab_size)
+
+        tokens = torch.multinomial(preds, 1).view(-1)
+        tokens = tokens.view(dinputs.shape[0], -1)
+
+        # Labels have a -100 value to ignore loss from unchanged tokens
+        mask = labels.ne(-100)
+
+        # Replace the masked out tokens of the input with the generator predictions
+        dinputs[mask] = tokens[mask]
+
+        # Turn mask into new target labels - 1 (True) for corrupted, 0 otherwise.
+        # If the prediction was correct, mark it as uncorrupted.
+        correct = tokens == labels
+        dlabels = mask.long()
+        dlabels[correct] = 0
+
+        # Run token classification, predict whether each token was corrupted
+        inputs = {"attention_mask": attention_mask} if self.dattention else {}
+        doutputs = self.discriminator(dinputs, labels=dlabels, token_type_ids=token_type_ids, **inputs)
+
+        # Compute combined loss
+        loss = goutputs[0] + self.weight * doutputs[0]
+        return loss, goutputs[1], doutputs[1], dlabels
+
+    def save_pretrained(self, output, state_dict=None, **kwargs):
+        """
+        Saves current model to output directory.
+
+        Args:
+            output: output directory
+            state_dict: model state
+            kwargs: additional keyword arguments
+        """
+
+        # Save combined model to support training from checkpoints
+        super().save_pretrained(output, state_dict, **kwargs)
+
+        # Save generator tokenizer and model
+        gpath = os.path.join(output, "generator")
+        self.tokenizer.save_pretrained(gpath)
+        self.generator.save_pretrained(gpath)
+
+        # Save discriminator tokenizer and model
+        dpath = os.path.join(output, "discriminator")
+        self.tokenizer.save_pretrained(dpath)
+        self.discriminator.save_pretrained(dpath)
txtai/pipeline/__init__.py
ADDED
@@ -0,0 +1,17 @@
+"""
+Pipeline imports
+"""
+
+from .audio import *
+from .base import Pipeline
+from .data import *
+from .factory import PipelineFactory
+from .hfmodel import HFModel
+from .hfpipeline import HFPipeline
+from .image import *
+from .llm import *
+from .llm import RAG as Extractor
+from .nop import Nop
+from .text import *
+from .tensors import Tensors
+from .train import *
txtai/pipeline/audio/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""
+Audio imports
+"""
+
+from .audiomixer import AudioMixer
+from .audiostream import AudioStream
+from .microphone import Microphone
+from .signal import Signal
+from .texttoaudio import TextToAudio
+from .texttospeech import TextToSpeech
+from .transcription import Transcription
txtai/pipeline/audio/audiomixer.py
ADDED
@@ -0,0 +1,58 @@
+"""
+AudioMixer module
+"""
+
+from ..base import Pipeline
+from .signal import Signal, SCIPY
+
+
+class AudioMixer(Pipeline):
+    """
+    Mixes multiple audio streams into a single stream.
+    """
+
+    def __init__(self, rate=None):
+        """
+        Creates an AudioMixer pipeline.
+
+        Args:
+            rate: optional target sample rate, otherwise uses input target rate with each audio segment
+        """
+
+        if not SCIPY:
+            raise ImportError('AudioMixer pipeline is not available - install "pipeline" extra to enable.')
+
+        # Target sample rate
+        self.rate = rate
+
+    def __call__(self, segment, scale1=1, scale2=1):
+        """
+        Mixes multiple audio streams into a single stream.
+
+        Args:
+            segment: ((audio1, sample rate), (audio2, sample rate))|list
+            scale1: optional scaling factor for segment1
+            scale2: optional scaling factor for segment2
+
+        Returns:
+            list of (audio, sample rate)
+        """
+
+        # Convert single element to list
+        segments = [segment] if isinstance(segment, tuple) else segment
+
+        results = []
+        for segment1, segment2 in segments:
+            audio1, rate1 = segment1
+            audio2, rate2 = segment2
+
+            # Resample audio, as necessary
+            target = self.rate if self.rate else rate1
+            audio1 = Signal.resample(audio1, rate1, target)
+            audio2 = Signal.resample(audio2, rate2, target)
+
+            # Mix audio into single segment
+            results.append((Signal.mix(audio1, audio2, scale1, scale2), target))
+
+        # Return single element if single element passed in
+        return results[0] if isinstance(segment, tuple) else results
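
A quick sketch of the call contract, using synthetic arrays (assumes the scipy-backed Signal helpers are installed via the "pipeline" extra):

import numpy as np

from txtai.pipeline import AudioMixer

# Two synthetic mono segments at different sample rates
voice = (np.zeros(16000, dtype=np.float32), 16000)
music = (np.zeros(22050, dtype=np.float32), 22050)

# Resamples both segments to 16 kHz, then mixes music at half volume
mixer = AudioMixer(rate=16000)
audio, rate = mixer((voice, music), scale1=1.0, scale2=0.5)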
txtai/pipeline/audio/audiostream.py
ADDED
@@ -0,0 +1,94 @@
+"""
+AudioStream module
+"""
+
+from queue import Queue
+from threading import Thread
+
+# Conditional import
+try:
+    import sounddevice as sd
+
+    from .signal import Signal, SCIPY
+
+    AUDIOSTREAM = SCIPY
+except (ImportError, OSError):
+    AUDIOSTREAM = False
+
+from ..base import Pipeline
+
+
+class AudioStream(Pipeline):
+    """
+    Threaded pipeline that streams audio segments to an output audio device. This pipeline is designed
+    to run on local machines given that it requires access to write to an output device.
+    """
+
+    # End of stream message
+    COMPLETE = (1, None)
+
+    def __init__(self, rate=None):
+        """
+        Creates an AudioStream pipeline.
+
+        Args:
+            rate: optional target sample rate, otherwise uses input target rate with each audio segment
+        """
+
+        if not AUDIOSTREAM:
+            raise ImportError(
+                (
+                    'AudioStream pipeline is not available - install "pipeline" extra to enable. '
+                    "Also check that the portaudio system library is available."
+                )
+            )
+
+        # Target sample rate
+        self.rate = rate
+
+        self.queue = Queue()
+        self.thread = Thread(target=self.play)
+        self.thread.start()
+
+    def __call__(self, segment):
+        """
+        Queues audio segments for the audio player.
+
+        Args:
+            segment: (audio, sample rate)|list
+
+        Returns:
+            segment
+        """
+
+        # Convert single element to list
+        segments = [segment] if isinstance(segment, tuple) else segment
+
+        for x in segments:
+            self.queue.put(x)
+
+        # Return single element if single element passed in
+        return segments[0] if isinstance(segment, tuple) else segments
+
+    def wait(self):
+        """
+        Waits for all input audio segments to be played.
+        """
+
+        self.thread.join()
+
+    def play(self):
+        """
+        Reads audio segments from queue. This method runs in a separate non-blocking thread.
+        """
+
+        audio, rate = self.queue.get()
+        while not isinstance(audio, int) or (audio, rate) != AudioStream.COMPLETE:
+            # Resample to target sample rate, if necessary
+            audio, rate = (Signal.resample(audio, rate, self.rate), self.rate) if self.rate else (audio, rate)
+
+            # Play audio segment
+            sd.play(audio, rate, blocking=True)
+
+            # Get next segment
+            audio, rate = self.queue.get()
txtai/pipeline/audio/microphone.py
ADDED
@@ -0,0 +1,244 @@
+"""
+Microphone module
+"""
+
+import logging
+
+import numpy as np
+
+# Conditional import
+try:
+    import sounddevice as sd
+    import webrtcvad
+
+    from scipy.signal import butter, sosfilt
+
+    from .signal import Signal, SCIPY
+
+    MICROPHONE = SCIPY
+except (ImportError, OSError):
+    MICROPHONE = False
+
+from ..base import Pipeline
+
+# Logging configuration
+logger = logging.getLogger(__name__)
+
+
+class Microphone(Pipeline):
+    """
+    Reads input speech from a microphone device. This pipeline is designed to run on local machines given
+    that it requires access to read from an input device.
+    """
+
+    def __init__(self, rate=16000, vadmode=3, vadframe=20, vadthreshold=0.6, voicestart=300, voiceend=3400, active=5, pause=8):
+        """
+        Creates a new Microphone pipeline.
+
+        Args:
+            rate: sample rate to record audio in, defaults to 16000 (16 kHz)
+            vadmode: aggressiveness of the voice activity detector (1 - 3), defaults to 3, which is the most aggressive filter
+            vadframe: voice activity detector frame size in ms, defaults to 20
+            vadthreshold: percentage of frames (0.0 - 1.0) that must be voice to be considered speech, defaults to 0.6
+            voicestart: starting frequency to use for voice filtering, defaults to 300
+            voiceend: ending frequency to use for voice filtering, defaults to 3400
+            active: minimum number of active speech chunks to require before considering this speech, defaults to 5
+            pause: number of non-speech chunks to keep before considering speech complete, defaults to 8
+        """
+
+        if not MICROPHONE:
+            raise ImportError(
+                (
+                    'Microphone pipeline is not available - install "pipeline" extra to enable. '
+                    "Also check that the portaudio system library is available."
+                )
+            )
+
+        # Sample rate
+        self.rate = rate
+
+        # Voice activity detector
+        self.vad = webrtcvad.Vad(vadmode)
+        self.vadframe = vadframe
+        self.vadthreshold = vadthreshold
+
+        # Voice spectrum
+        self.voicestart = voicestart
+        self.voiceend = voiceend
+
+        # Audio chunks counts
+        self.active = active
+        self.pause = pause
+
+    def __call__(self, device=None):
+        """
+        Reads audio from an input device.
+
+        Args:
+            device: optional input device id, otherwise uses system default
+
+        Returns:
+            list of (audio, sample rate)
+        """
+
+        # Listen for audio
+        audio = self.listen(device[0] if isinstance(device, list) else device)
+
+        # Return single element if single element passed in
+        return (audio, self.rate) if device is None or not isinstance(device, list) else [(audio, self.rate)]
+
+    def listen(self, device):
+        """
+        Listens for speech. Detected speech is converted to 32-bit floats for compatibility with
+        automatic speech recognition (ASR) pipelines.
+
+        This method blocks until speech is detected.
+
+        Args:
+            device: input device
+
+        Returns:
+            audio
+        """
+
+        # Record in 100ms chunks
+        chunksize = self.rate // 10
+
+        # Open input stream
+        stream = sd.RawInputStream(device=device, samplerate=self.rate, channels=1, blocksize=chunksize, dtype=np.int16)
+
+        # Start the input stream
+        stream.start()
+
+        record, speech, nospeech, chunks = True, 0, 0, []
+        while record:
+            # Read chunk
+            chunk, _ = stream.read(chunksize)
+
+            # Detect speech using WebRTC VAD for audio chunk
+            detect = self.detect(chunk)
+            speech = speech + 1 if detect else speech
+            nospeech = 0 if detect else nospeech + 1
+
+            # Save chunk, if this is an active stream
+            if speech:
+                chunks.append(chunk)
+
+                # Pause limit has been reached, check if this audio should be accepted
+                if nospeech >= self.pause:
+                    logger.debug("Audio detected and being analyzed")
+                    if speech >= self.active and self.isspeech(chunks[:-nospeech]):
+                        # Disable recording
+                        record = False
+                    else:
+                        # Reset parameters and keep recording
+                        logger.debug("Speech not detected")
+                        speech, nospeech, chunks = 0, 0, []
+
+        # Stop the input stream
+        stream.stop()
+
+        # Convert to float32 and return
+        audio = np.frombuffer(b"".join(chunks), np.int16)
+        return Signal.float32(audio)
+
+    def isspeech(self, chunks):
+        """
+        Runs an ensemble of Voice Activity Detection (VAD) methods. Returns true if speech is
+        detected in the input audio chunks.
+
+        Args:
+            chunks: input audio chunks as byte buffers
+
+        Returns:
+            True if speech is detected, False otherwise
+        """
+
+        # Convert to NumPy array for processing
+        audio = np.frombuffer(b"".join(chunks), dtype=np.int16)
+
+        # Ensemble of:
+        #  - WebRTC VAD with a human voice range Butterworth bandpass filter applied to the signal
+        #  - FFT applied to detect the energy ratio for human voice range vs total range
+        return self.detectband(audio) and self.detectenergy(audio)
+
+    def detect(self, buffer):
+        """
+        Detect speech using the WebRTC Voice Activity Detector (VAD).
+
+        Args:
+            buffer: input audio buffer frame as bytes
+
+        Returns:
+            True if the ratio of frames with detected speech meets vadthreshold, False otherwise
+        """
+
+        n = int(self.rate * (self.vadframe / 1000.0) * 2)
+        offset = 0
+
+        detects = []
+        while offset + n <= len(buffer):
+            detects.append(1 if self.vad.is_speech(buffer[offset : offset + n], self.rate) else 0)
+            offset += n
+
+        # Calculate detection ratio and return
+        ratio = sum(detects) / len(detects) if detects else 0
+        if ratio > 0:
+            logger.debug("DETECT %.4f", ratio)
+
+        return ratio >= self.vadthreshold
+
+    def detectband(self, audio):
+        """
+        Detects speech using audio data filtered through a Butterworth bandpass filter
+        covering the human voice range.
+
+        Args:
+            audio: input audio data as a NumPy array
+
+        Returns:
+            True if speech is detected, False otherwise
+        """
+
+        # Upsample to float32
+        audio = Signal.float32(audio)
+
+        # Human voice frequency range
+        low = self.voicestart / (0.5 * self.rate)
+        high = self.voiceend / (0.5 * self.rate)
+
+        # Low and high pass filter using human voice range
+        sos = butter(5, Wn=[low, high], btype="band", output="sos")
+        audio = sosfilt(sos, audio)
+
+        # Scale back to int16
+        audio = Signal.int16(audio)
+
+        # Pass filtered signal to WebRTC VAD
+        return self.detect(audio.tobytes())
+
+    def detectenergy(self, audio):
+        """
+        Detects speech by comparing the signal energy of the human voice range
+        to the overall signal energy.
+
+        Args:
+            audio: input audio data as a NumPy array
+
+        Returns:
+            True if speech is detected, False otherwise
+        """
+
+        # Calculate signal energy
+        energyfreq = Signal.energy(audio, self.rate)
+
+        # Sum speech energy
+        speechenergy = 0
+        for f, e in energyfreq.items():
+            if self.voicestart <= f <= self.voiceend:
+                speechenergy += e
+
+        # Calculate ratio of speech energy to total energy and return
+        ratio = speechenergy / sum(energyfreq.values())
+        logger.debug("SPEECH %.4f", ratio)
+        return ratio >= self.vadthreshold