mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/llm/rag.py
ADDED
@@ -0,0 +1,477 @@
"""
RAG module
"""

from ...models import Models

from ..base import Pipeline
from ..data import Tokenizer
from ..text import Questions
from ..text import Similarity

from .factory import GenerationFactory
from .llm import LLM


class RAG(Pipeline):
    """
    Extracts knowledge from content by joining a prompt, context data store and generative model together. The data store can be
    an embeddings database or a similarity instance with associated input text. The generative model can be a prompt-driven large
    language model (LLM), an extractive question-answering model or a custom pipeline. This is known as retrieval augmented generation (RAG).
    """

    # pylint: disable=R0913
    def __init__(
        self,
        similarity,
        path,
        quantize=False,
        gpu=True,
        model=None,
        tokenizer=None,
        minscore=None,
        mintokens=None,
        context=None,
        task=None,
        output="default",
        template=None,
        separator=" ",
        system=None,
        **kwargs,
    ):
        """
        Builds a new RAG pipeline.

        Args:
            similarity: similarity instance (embeddings or similarity pipeline)
            path: path to model, supports a LLM, Questions or custom pipeline
            quantize: True if model should be quantized before inference, False otherwise.
            gpu: if gpu inference should be used (only works if GPUs are available)
            model: optional existing pipeline model to wrap
            tokenizer: Tokenizer class
            minscore: minimum score to include context match, defaults to None
            mintokens: minimum number of tokens to include context match, defaults to None
            context: topn context matches to include, defaults to 3
            task: model task (language-generation, sequence-sequence or question-answering), defaults to auto-detect
            output: output format, 'default' returns (name, answer), 'flatten' returns answers and 'reference' returns (name, answer, reference)
            template: prompt template, it must have a parameter for {question} and {context}, defaults to "{question} {context}"
            separator: context separator
            system: system prompt, defaults to None
            kwargs: additional keyword arguments to pass to pipeline model
        """

        # Similarity instance
        self.similarity = similarity

        # Model can be a LLM, Questions or custom pipeline
        self.model = self.load(path, quantize, gpu, model, task, **kwargs)

        # Tokenizer class, uses default method if not set
        self.tokenizer = tokenizer if tokenizer else Tokenizer() if hasattr(self.similarity, "scoring") and self.similarity.isweighted() else None

        # Minimum score to include context match
        self.minscore = minscore if minscore is not None else 0.0

        # Minimum number of tokens to include context match
        self.mintokens = mintokens if mintokens is not None else 0.0

        # Top n context matches to include for context
        self.context = context if context else 3

        # Output format
        self.output = output

        # Prompt template
        self.template = template if template else "{question} {context}"

        # Context separator
        self.separator = separator

        # System prompt template
        self.system = system

    def __call__(self, queue, texts=None, **kwargs):
        """
        Finds answers to input questions. This method runs queries to find the top n best matches and uses that as the context.
        A model is then run against the context for each input question, with the answer returned.

        Args:
            queue: input question queue (name, query, question, snippet), can be list of tuples/dicts/strings or a single input element
            texts: optional list of text for context, otherwise runs embeddings search
            kwargs: additional keyword arguments to pass to pipeline model

        Returns:
            list of answers matching input format (tuple or dict) containing fields as specified by output format
        """

        # Save original queue format
        inputs = queue

        # Convert queue to list, if necessary
        queue = queue if isinstance(queue, list) else [queue]

        # Convert dictionary inputs to tuples
        if queue and isinstance(queue[0], dict):
            # Convert dict to tuple
            queue = [tuple(row.get(x) for x in ["name", "query", "question", "snippet"]) for row in queue]

        if queue and isinstance(queue[0], str):
            # Convert string questions to tuple
            queue = [(None, row, row, None) for row in queue]

        # Rank texts by similarity for each query
        results = self.query([query for _, query, _, _ in queue], texts)

        # Build question-context pairs
        names, queries, questions, contexts, topns, snippets = [], [], [], [], [], []
        for x, (name, query, question, snippet) in enumerate(queue):
            # Get top n best matching segments
            topn = sorted(results[x], key=lambda y: y[2], reverse=True)[: self.context]

            # Generate context using ordering from texts, if available, otherwise order by score
            context = self.separator.join(text for _, text, _ in (sorted(topn, key=lambda y: y[0]) if texts else topn))

            names.append(name)
            queries.append(query)
            questions.append(question)
            contexts.append(context)
            topns.append(topn)
            snippets.append(snippet)

        # Run pipeline and return answers
        answers = self.answers(questions, contexts, **kwargs)

        # Apply output formatting to answers and return
        return self.apply(inputs, names, queries, answers, topns, snippets) if isinstance(answers, list) else answers

    def load(self, path, quantize, gpu, model, task, **kwargs):
        """
        Loads a LLM, Questions or custom pipeline.

        Args:
            path: path to model, supports a LLM, Questions or custom pipeline
            quantize: True if model should be quantized before inference, False otherwise.
            gpu: if gpu inference should be used (only works if GPUs are available)
            model: optional existing pipeline model to wrap
            task: model task (language-generation, sequence-sequence or question-answering), defaults to auto-detect
            kwargs: additional keyword arguments to pass to pipeline model

        Returns:
            LLM, Questions or custom pipeline
        """

        # Only try to load if path is a string
        if not isinstance(path, str):
            return path

        # Attempt to resolve task if not provided
        task = GenerationFactory.method(path, task)
        task = Models.task(path, **kwargs) if task == "transformers" else task

        # Load Questions pipeline
        if task == "question-answering":
            return Questions(path, quantize, gpu, model, **kwargs)

        # Load LLM pipeline
        return LLM(path=path, quantize=quantize, gpu=gpu, model=model, task=task, **kwargs)

    def query(self, queries, texts):
        """
        Ranks texts by similarity for each query. If texts is empty, an embeddings search will be executed.
        Returns results sorted by best match.

        Args:
            queries: list of queries
            texts: optional list of text

        Returns:
            list of (id, data, score) per query
        """

        if not queries:
            return []

        # Score text against queries
        scores, segments, tokenlist = self.score(queries, texts)

        # Build question-context pairs
        results = []
        for i, query in enumerate(queries):
            # Get list of required and prohibited tokens
            must = [token.strip("+") for token in query.split() if token.startswith("+") and len(token) > 1]
            mnot = [token.strip("-") for token in query.split() if token.startswith("-") and len(token) > 1]

            # Segment text is static when texts is passed in but different per query when an embeddings search is run
            segment = segments if texts else segments[i]
            tokens = tokenlist if texts else tokenlist[i]

            # List of matches
            matches = []
            for y, (x, score) in enumerate(scores[i]):
                # Segments and tokens are statically ordered when texts is passed in, need to resolve values with score id
                # Scores, segments and tokens all share the same list ordering when an embeddings search is run
                x = x if texts else y

                # Get segment text
                text = segment[x][1]

                # Add result if:
                # - all required tokens are present or there are no required tokens AND
                # - all prohibited tokens are absent or there are no prohibited tokens AND
                # - score is above the minimum score required AND
                # - number of tokens is above the minimum number of tokens required
                if (not must or all(token.lower() in text.lower() for token in must)) and (
                    not mnot or all(token.lower() not in text.lower() for token in mnot)
                ):
                    if score >= self.minscore and len(tokens[x]) >= self.mintokens:
                        matches.append(segment[x] + (score,))

            # Add query matches sorted by highest score
            results.append(matches)

        return results

    def score(self, queries, texts):
        """
        Runs queries against texts (or an embeddings search if texts is empty) and builds list of
        similarity scores for each query-text combination.

        Args:
            queries: list of queries
            texts: optional list of text

        Returns:
            scores, segments, tokenlist
        """

        # Tokenize text
        segments, tokenlist = [], []
        if texts:
            for text in texts:
                # Run tokenizer method, if available, otherwise returns original text
                tokens = self.tokenize(text)
                if tokens:
                    segments.append(text)
                    tokenlist.append(tokens)

            # Add index id to segments to preserve ordering after filters
            segments = list(enumerate(segments))

        # Get list of (id, score) - sorted by highest score per query
        if isinstance(self.similarity, Similarity):
            # Score using similarity pipeline
            scores = self.similarity(queries, [t for _, t in segments])
        elif texts:
            # Score using embeddings.batchsimilarity
            scores = self.similarity.batchsimilarity([self.tokenize(x) for x in queries], tokenlist)
        else:
            # Score using embeddings.batchsearch
            scores, segments, tokenlist = self.batchsearch(queries)

        return scores, segments, tokenlist

    def batchsearch(self, queries):
        """
        Runs a batch embeddings search for a set of queries.

        Args:
            queries: list of queries to run

        Returns:
            scores, segments, tokenlist
        """

        scores, segments, tokenlist = [], [], []
        for results in self.similarity.batchsearch([self.tokenize(x) for x in queries], self.context):
            # Assume embeddings content is enabled and results are dictionaries
            scores.append([(result["id"], result["score"]) for result in results])
            segments.append([(result["id"], result["text"]) for result in results])
            tokenlist.append([self.tokenize(result["text"]) for result in results])

        return scores, segments, tokenlist

    def tokenize(self, text):
        """
        Tokenizes text. Returns original text if tokenizer is not available.

        Args:
            text: input text

        Returns:
            tokens if tokenizer available otherwise original text
        """

        return self.tokenizer(text) if self.tokenizer else text

    def answers(self, questions, contexts, **kwargs):
        """
        Executes pipeline and formats extracted answers.

        Args:
            questions: questions
            contexts: question context
            kwargs: additional keyword arguments to pass to model

        Returns:
            answers
        """

        # Run model inference with questions pipeline
        if isinstance(self.model, Questions):
            return self.model(questions, contexts)

        # Run generator pipeline
        return self.model(self.prompts(questions, contexts), **kwargs)

    def prompts(self, questions, contexts):
        """
        Builds a list of prompts using the passed in questions and contexts.

        Args:
            questions: questions
            contexts: question context

        Returns:
            prompts
        """

        # Format prompts for generator pipeline
        prompts = []
        for x, context in enumerate(contexts):
            # Create input prompt
            prompt = self.template.format(question=questions[x], context=context)

            # Add system prompt, if necessary
            if self.system:
                prompt = [
                    {"role": "system", "content": self.system.format(question=questions[x], context=context)},
                    {"role": "user", "content": prompt},
                ]

            prompts.append(prompt)

        return prompts

    def apply(self, inputs, names, queries, answers, topns, snippets):
        """
        Applies the following formatting rules to answers.
          - each answer row matches input format (tuple or dict)
          - if output format is 'flatten' then this method flattens to a list of answers
          - if output format is 'reference' then a list of (name, answer, reference) is returned
          - otherwise, if output format is 'default' or anything else, a list of (name, answer) is returned

        Args:
            inputs: original inputs
            names: question identifiers/names
            queries: list of input queries
            answers: list of generated answers
            topns: top n records used for context
            snippets: flags to enable answer snippets per answer

        Returns:
            list of answers matching input format (tuple or dict) containing fields as specified by output format
        """

        # Resolve answers as snippets
        answers = self.snippets(names, answers, topns, snippets)

        # Flatten to list of answers and return
        if self.output == "flatten":
            answers = [answer for _, answer in answers]
        else:
            # Resolve id reference for each answer
            if self.output == "reference":
                answers = self.reference(queries, answers, topns)

            # Ensure output format matches input format
            first = inputs[0] if inputs and isinstance(inputs, list) else inputs
            if isinstance(first, (dict, str)):
                # Add name if input queue had name field
                fields = ["name", "answer", "reference"] if isinstance(first, dict) and "name" in first else [None, "answer", "reference"]
                answers = [{fields[x]: column for x, column in enumerate(row) if fields[x]} for row in answers]

        # Unpack single answer, if necessary
        return answers[0] if answers and isinstance(inputs, (tuple, dict, str)) else answers

    def snippets(self, names, answers, topns, snippets):
        """
        Extracts text surrounding the answer within context.

        Args:
            names: question identifiers/names
            answers: list of generated answers
            topns: top n records used for context
            snippets: flags to enable answer snippets per answer

        Returns:
            answers resolved as snippets per question, if necessary
        """

        # Extract and format answer
        results = []

        for x, answer in enumerate(answers):
            # Resolve snippet if necessary
            if answer and snippets[x]:
                # Searches for first text element to contain answer
                for _, text, _ in topns[x]:
                    if answer in text:
                        answer = text
                        break

            results.append((names[x], answer))

        return results

    def reference(self, queries, answers, topns):
        """
        References each answer with the best matching context element id.

        Args:
            queries: list of input queries
            answers: list of answers
            topns: top n context elements as (id, data, tag)

        Returns:
            list of (name, answer, reference)
        """

        # Convert queries to terms
        terms = self.terms(queries)

        outputs = []
        for x, (name, answer) in enumerate(answers):
            # Get matching topn
            topn, reference = topns[x], None

            if topn:
                # Build query from keyword terms and the answer text
                query = f"{terms[x]} {answers[x][1]}"

                # Compare answer to topns to find best match
                scores, _, _ = self.score([query], [text for _, text, _ in topn])

                # Get top score index
                index = scores[0][0][0]

                # Use matching topn id as reference
                reference = topn[index][0]

            # Append (name, answer, reference) tuple
            outputs.append((name, answer, reference))

        return outputs

    def terms(self, queries):
        """
        Extracts keyword terms from a list of queries using underlying similarity model.

        Args:
            queries: list of queries

        Returns:
            list of queries reduced down to keyword term strings
        """

        # Extract keyword terms from queries if underlying similarity model supports it
        return self.similarity.batchterms(queries) if hasattr(self.similarity, "batchterms") else queries
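For reference, a minimal usage sketch of the pipeline above. It assumes an existing txtai Embeddings index with content storage enabled; the model paths, documents and question are illustrative placeholders, not part of this package diff:

    from txtai import Embeddings
    from txtai.pipeline import RAG

    # Build a small embeddings index with content storage enabled (hypothetical data)
    embeddings = Embeddings(path="sentence-transformers/all-MiniLM-L6-v2", content=True)
    embeddings.index(["Maine is the northeasternmost US state", "Texas borders Mexico"])

    # Join the embeddings index and a generative model into a RAG pipeline
    rag = RAG(embeddings, "google/flan-t5-base", template="Answer using the context. {question} {context}")

    # A single string input returns a single {"answer": ...} dict per apply()
    print(rag("Which state is furthest northeast?"))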
txtai/pipeline/tensors.py
ADDED
@@ -0,0 +1,52 @@
"""
Tensor processing framework module
"""

import torch

from .base import Pipeline


class Tensors(Pipeline):
    """
    Pipeline backed by a tensor processing framework. Currently supports PyTorch.
    """

    def quantize(self, model):
        """
        Quantizes the input model and returns it. This is only supported on CPU devices.

        Args:
            model: torch model

        Returns:
            quantized torch model
        """

        # pylint: disable=E1101
        return torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

    def tensor(self, data):
        """
        Creates a tensor array.

        Args:
            data: input data

        Returns:
            tensor
        """

        # pylint: disable=E1102
        return torch.tensor(data)

    def context(self):
        """
        Defines a context used to wrap processing with the tensor processing framework.

        Returns:
            processing context
        """

        # pylint: disable=E1101
        return torch.no_grad()
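As a rough illustration of the dynamic quantization performed by quantize above, the same torch calls applied to a toy model (the model itself is made up for this example):

    import torch

    # Toy CPU model with Linear layers, the only module type quantized here
    model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

    # Dynamic quantization replaces Linear weights with int8 equivalents for faster CPU inference
    quantized = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

    # Run inference inside the no-grad context that context() returns
    with torch.no_grad():
        print(quantized(torch.randn(1, 16)))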
txtai/pipeline/text/__init__.py
ADDED
@@ -0,0 +1,13 @@
"""
Text imports
"""

from .crossencoder import CrossEncoder
from .entity import Entity
from .labels import Labels
from .lateencoder import LateEncoder
from .questions import Questions
from .reranker import Reranker
from .similarity import Similarity
from .summary import Summary
from .translation import Translation
txtai/pipeline/text/crossencoder.py
ADDED
@@ -0,0 +1,70 @@
"""
CrossEncoder module
"""

import numpy as np

from ..hfpipeline import HFPipeline


class CrossEncoder(HFPipeline):
    """
    Computes similarity between query and list of text using a cross-encoder model
    """

    def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
        super().__init__("text-classification", path, quantize, gpu, model, **kwargs)

    def __call__(self, query, texts, multilabel=True, workers=0):
        """
        Computes the similarity between query and list of text. Returns a list of
        (id, score) sorted by highest score, where id is the index in texts.

        This method supports query as a string or a list. If the input is a string,
        the return type is a 1D list of (id, score). If the input is a list, a 2D list
        of (id, score) is returned with a row per string.

        Args:
            query: query text|list
            texts: list of text
            multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None
            workers: number of concurrent workers to use for processing data, defaults to 0

        Returns:
            list of (id, score)
        """

        scores = []
        for q in [query] if isinstance(query, str) else query:
            # Pass (query, text) pairs to model
            result = self.pipeline([{"text": q, "text_pair": t} for t in texts], top_k=None, function_to_apply="none", num_workers=workers)

            # Apply score transform function
            scores.append(self.function([r[0]["score"] for r in result], multilabel))

        # Build list of (id, score) per query sorted by highest score
        scores = [sorted(enumerate(row), key=lambda x: x[1], reverse=True) for row in scores]

        return scores[0] if isinstance(query, str) else scores

    def function(self, scores, multilabel):
        """
        Applies an output transformation function based on the value of multilabel.

        Args:
            scores: input scores
            multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None

        Returns:
            transformed scores
        """

        # Output functions
        # pylint: disable=C3001
        identity = lambda x: x
        sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
        softmax = lambda x: np.exp(x) / np.sum(np.exp(x))
        function = identity if multilabel is None else sigmoid if multilabel else softmax

        # Apply output function
        return function(np.array(scores))
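A short usage sketch for the class above; the import path follows the package layout in the file list and the model path is a placeholder for any Hugging Face text-classification cross-encoder, both assumed rather than taken from this diff:

    from txtai.pipeline import CrossEncoder

    # Load a cross-encoder similarity model (hypothetical path)
    similarity = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    # Returns a list of (id, score) sorted by highest score, where id indexes into texts
    print(similarity("What is the capital of France?", ["Paris is the capital of France", "Berlin is in Germany"]))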