mseep_txtai-9.1.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/llm/rag.py ADDED
@@ -0,0 +1,477 @@
+ """
+ RAG module
+ """
+
+ from ...models import Models
+
+ from ..base import Pipeline
+ from ..data import Tokenizer
+ from ..text import Questions
+ from ..text import Similarity
+
+ from .factory import GenerationFactory
+ from .llm import LLM
+
+
+ class RAG(Pipeline):
+     """
+     Extracts knowledge from content by joining a prompt, context data store and generative model together. The data store can be
+     an embeddings database or a similarity instance with associated input text. The generative model can be a prompt-driven large
+     language model (LLM), an extractive question-answering model or a custom pipeline. This is known as retrieval augmented generation (RAG).
+     """
+
+     # pylint: disable=R0913
+     def __init__(
+         self,
+         similarity,
+         path,
+         quantize=False,
+         gpu=True,
+         model=None,
+         tokenizer=None,
+         minscore=None,
+         mintokens=None,
+         context=None,
+         task=None,
+         output="default",
+         template=None,
+         separator=" ",
+         system=None,
+         **kwargs,
+     ):
+         """
+         Builds a new RAG pipeline.
+
+         Args:
+             similarity: similarity instance (embeddings or similarity pipeline)
+             path: path to model, supports a LLM, Questions or custom pipeline
+             quantize: True if model should be quantized before inference, False otherwise
+             gpu: if gpu inference should be used (only works if GPUs are available)
+             model: optional existing pipeline model to wrap
+             tokenizer: Tokenizer class
+             minscore: minimum score to include context match, defaults to None
+             mintokens: minimum number of tokens to include context match, defaults to None
+             context: topn context matches to include, defaults to 3
+             task: model task (language-generation, sequence-sequence or question-answering), defaults to auto-detect
+             output: output format, 'default' returns (name, answer), 'flatten' returns answers and 'reference' returns (name, answer, reference)
+             template: prompt template, it must have a parameter for {question} and {context}, defaults to "{question} {context}"
+             separator: context separator
+             system: system prompt, defaults to None
+             kwargs: additional keyword arguments to pass to pipeline model
+         """
+
+         # Similarity instance
+         self.similarity = similarity
+
+         # Model can be a LLM, Questions or custom pipeline
+         self.model = self.load(path, quantize, gpu, model, task, **kwargs)
+
+         # Tokenizer class, use default method if not set
+         self.tokenizer = tokenizer if tokenizer else Tokenizer() if hasattr(self.similarity, "scoring") and self.similarity.isweighted() else None
+
+         # Minimum score to include context match
+         self.minscore = minscore if minscore is not None else 0.0
+
+         # Minimum number of tokens to include context match
+         self.mintokens = mintokens if mintokens is not None else 0.0
+
+         # Top n context matches to include for context
+         self.context = context if context else 3
+
+         # Output format
+         self.output = output
+
+         # Prompt template
+         self.template = template if template else "{question} {context}"
+
+         # Context separator
+         self.separator = separator
+
+         # System prompt template
+         self.system = system
+
+     def __call__(self, queue, texts=None, **kwargs):
+         """
+         Finds answers to input questions. This method runs queries to find the top n best matches and uses that as the context.
+         A model is then run against the context for each input question, with the answer returned.
+
+         Args:
+             queue: input question queue (name, query, question, snippet), can be list of tuples/dicts/strings or a single input element
+             texts: optional list of text for context, otherwise runs embeddings search
+             kwargs: additional keyword arguments to pass to pipeline model
+
+         Returns:
+             list of answers matching input format (tuple or dict) containing fields as specified by output format
+         """
+
+         # Save original queue format
+         inputs = queue
+
+         # Convert queue to list, if necessary
+         queue = queue if isinstance(queue, list) else [queue]
+
+         # Convert dictionary inputs to tuples
+         if queue and isinstance(queue[0], dict):
+             # Convert dict to tuple
+             queue = [tuple(row.get(x) for x in ["name", "query", "question", "snippet"]) for row in queue]
+
+         if queue and isinstance(queue[0], str):
+             # Convert string questions to tuple
+             queue = [(None, row, row, None) for row in queue]
+
+         # Rank texts by similarity for each query
+         results = self.query([query for _, query, _, _ in queue], texts)
+
+         # Build question-context pairs
+         names, queries, questions, contexts, topns, snippets = [], [], [], [], [], []
+         for x, (name, query, question, snippet) in enumerate(queue):
+             # Get top n best matching segments
+             topn = sorted(results[x], key=lambda y: y[2], reverse=True)[: self.context]
+
+             # Generate context using ordering from texts, if available, otherwise order by score
+             context = self.separator.join(text for _, text, _ in (sorted(topn, key=lambda y: y[0]) if texts else topn))
+
+             names.append(name)
+             queries.append(query)
+             questions.append(question)
+             contexts.append(context)
+             topns.append(topn)
+             snippets.append(snippet)
+
+         # Run pipeline and return answers
+         answers = self.answers(questions, contexts, **kwargs)
+
+         # Apply output formatting to answers and return
+         return self.apply(inputs, names, queries, answers, topns, snippets) if isinstance(answers, list) else answers
+
+     def load(self, path, quantize, gpu, model, task, **kwargs):
+         """
+         Loads a LLM, Questions or custom pipeline.
+
+         Args:
+             path: path to model, supports a LLM, Questions or custom pipeline
+             quantize: True if model should be quantized before inference, False otherwise
+             gpu: if gpu inference should be used (only works if GPUs are available)
+             model: optional existing pipeline model to wrap
+             task: model task (language-generation, sequence-sequence or question-answering), defaults to auto-detect
+             kwargs: additional keyword arguments to pass to pipeline model
+
+         Returns:
+             LLM, Questions or custom pipeline
+         """
+
+         # Only try to load if path is a string
+         if not isinstance(path, str):
+             return path
+
+         # Attempt to resolve task if not provided
+         task = GenerationFactory.method(path, task)
+         task = Models.task(path, **kwargs) if task == "transformers" else task
+
+         # Load Questions pipeline
+         if task == "question-answering":
+             return Questions(path, quantize, gpu, model, **kwargs)
+
+         # Load LLM pipeline
+         return LLM(path=path, quantize=quantize, gpu=gpu, model=model, task=task, **kwargs)
+
+     def query(self, queries, texts):
+         """
+         Ranks texts by similarity for each query. If texts is empty, an embeddings search is executed.
+         Returns results sorted by best match.
+
+         Args:
+             queries: list of queries
+             texts: optional list of text
+
+         Returns:
+             list of (id, data, score) per query
+         """
+
+         if not queries:
+             return []
+
+         # Score text against queries
+         scores, segments, tokenlist = self.score(queries, texts)
+
+         # Build question-context pairs
+         results = []
+         for i, query in enumerate(queries):
+             # Get list of required and prohibited tokens
+             must = [token.strip("+") for token in query.split() if token.startswith("+") and len(token) > 1]
+             mnot = [token.strip("-") for token in query.split() if token.startswith("-") and len(token) > 1]
+
+             # Segment text is static when texts is passed in but different per query when an embeddings search is run
+             segment = segments if texts else segments[i]
+             tokens = tokenlist if texts else tokenlist[i]
+
+             # List of matches
+             matches = []
+             for y, (x, score) in enumerate(scores[i]):
+                 # Segments and tokens are statically ordered when texts is passed in, need to resolve values with score id
+                 # Scores, segments and tokens all share the same list ordering when an embeddings search is run
+                 x = x if texts else y
+
+                 # Get segment text
+                 text = segment[x][1]
+
+                 # Add result if:
+                 #   - all required tokens are present or there are no required tokens AND
+                 #   - no prohibited tokens are present or there are no prohibited tokens AND
+                 #   - score is above the minimum required score AND
+                 #   - token count is above the minimum required number of tokens
+                 if (not must or all(token.lower() in text.lower() for token in must)) and (
+                     not mnot or all(token.lower() not in text.lower() for token in mnot)
+                 ):
+                     if score >= self.minscore and len(tokens[x]) >= self.mintokens:
+                         matches.append(segment[x] + (score,))
+
+             # Add query matches sorted by highest score
+             results.append(matches)
+
+         return results
+
+     def score(self, queries, texts):
+         """
+         Runs queries against texts (or an embeddings search if texts is empty) and builds list of
+         similarity scores for each query-text combination.
+
+         Args:
+             queries: list of queries
+             texts: optional list of text
+
+         Returns:
+             scores, segments, tokenlist
+         """
+
+         # Tokenize text
+         segments, tokenlist = [], []
+         if texts:
+             for text in texts:
+                 # Run tokenizer method, if available, otherwise keep original text
+                 tokens = self.tokenize(text)
+                 if tokens:
+                     segments.append(text)
+                     tokenlist.append(tokens)
+
+             # Add index id to segments to preserve ordering after filters
+             segments = list(enumerate(segments))
+
+         # Get list of (id, score) - sorted by highest score per query
+         if isinstance(self.similarity, Similarity):
+             # Score using similarity pipeline
+             scores = self.similarity(queries, [t for _, t in segments])
+         elif texts:
+             # Score using embeddings.batchsimilarity
+             scores = self.similarity.batchsimilarity([self.tokenize(x) for x in queries], tokenlist)
+         else:
+             # Score using embeddings.batchsearch
+             scores, segments, tokenlist = self.batchsearch(queries)
+
+         return scores, segments, tokenlist
+
+     def batchsearch(self, queries):
+         """
+         Runs a batch embeddings search for a set of queries.
+
+         Args:
+             queries: list of queries to run
+
+         Returns:
+             scores, segments, tokenlist
+         """
+
+         scores, segments, tokenlist = [], [], []
+         for results in self.similarity.batchsearch([self.tokenize(x) for x in queries], self.context):
+             # Assume embeddings content is enabled and results are dictionaries
+             scores.append([(result["id"], result["score"]) for result in results])
+             segments.append([(result["id"], result["text"]) for result in results])
+             tokenlist.append([self.tokenize(result["text"]) for result in results])
+
+         return scores, segments, tokenlist
+
+     def tokenize(self, text):
+         """
+         Tokenizes text. Returns original text if tokenizer is not available.
+
+         Args:
+             text: input text
+
+         Returns:
+             tokens if tokenizer available otherwise original text
+         """
+
+         return self.tokenizer(text) if self.tokenizer else text
+
+     def answers(self, questions, contexts, **kwargs):
+         """
+         Executes pipeline and formats extracted answers.
+
+         Args:
+             questions: questions
+             contexts: question context
+             kwargs: additional keyword arguments to pass to model
+
+         Returns:
+             answers
+         """
+
+         # Run model inference with questions pipeline
+         if isinstance(self.model, Questions):
+             return self.model(questions, contexts)
+
+         # Run generator pipeline
+         return self.model(self.prompts(questions, contexts), **kwargs)
+
+     def prompts(self, questions, contexts):
+         """
+         Builds a list of prompts using the passed in questions and contexts.
+
+         Args:
+             questions: questions
+             contexts: question context
+
+         Returns:
+             prompts
+         """
+
+         # Format prompts for generator pipeline
+         prompts = []
+         for x, context in enumerate(contexts):
+             # Create input prompt
+             prompt = self.template.format(question=questions[x], context=context)
+
+             # Add system prompt, if necessary
+             if self.system:
+                 prompt = [
+                     {"role": "system", "content": self.system.format(question=questions[x], context=context)},
+                     {"role": "user", "content": prompt},
+                 ]
+
+             prompts.append(prompt)
+
+         return prompts
+
+     def apply(self, inputs, names, queries, answers, topns, snippets):
+         """
+         Applies the following formatting rules to answers.
+           - each answer row matches input format (tuple or dict)
+           - if output format is 'flatten', this method flattens to a list of answers
+           - if output format is 'reference', a list of (name, answer, reference) is returned
+           - otherwise, for output format 'default' or anything else, a list of (name, answer) is returned
+
+         Args:
+             inputs: original inputs
+             names: question identifiers/names
+             queries: list of input queries
+             answers: list of generated answers
+             topns: top n records used for context
+             snippets: flags to enable answer snippets per answer
+
+         Returns:
+             list of answers matching input format (tuple or dict) containing fields as specified by output format
+         """
+
+         # Resolve answers as snippets
+         answers = self.snippets(names, answers, topns, snippets)
+
+         # Flatten to list of answers and return
+         if self.output == "flatten":
+             answers = [answer for _, answer in answers]
+         else:
+             # Resolve id reference for each answer
+             if self.output == "reference":
+                 answers = self.reference(queries, answers, topns)
+
+             # Ensure output format matches input format
+             first = inputs[0] if inputs and isinstance(inputs, list) else inputs
+             if isinstance(first, (dict, str)):
+                 # Add name if input queue had name field
+                 fields = ["name", "answer", "reference"] if isinstance(first, dict) and "name" in first else [None, "answer", "reference"]
+                 answers = [{fields[x]: column for x, column in enumerate(row) if fields[x]} for row in answers]
+
+         # Unpack single answer, if necessary
+         return answers[0] if answers and isinstance(inputs, (tuple, dict, str)) else answers
+
+     def snippets(self, names, answers, topns, snippets):
+         """
+         Extracts text surrounding the answer within context.
+
+         Args:
+             names: question identifiers/names
+             answers: list of generated answers
+             topns: top n records used for context
+             snippets: flags to enable answer snippets per answer
+
+         Returns:
+             answers resolved as snippets per question, if necessary
+         """
+
+         # Extract and format answer
+         results = []
+
+         for x, answer in enumerate(answers):
+             # Resolve snippet if necessary
+             if answer and snippets[x]:
+                 # Search for the first text element that contains the answer
+                 for _, text, _ in topns[x]:
+                     if answer in text:
+                         answer = text
+                         break
+
+             results.append((names[x], answer))
+
+         return results
+
+     def reference(self, queries, answers, topns):
+         """
+         References each answer with the best matching context element id.
+
+         Args:
+             queries: list of input queries
+             answers: list of answers
+             topns: top n context elements as (id, data, score)
+
+         Returns:
+             list of (name, answer, reference)
+         """
+
+         # Convert queries to terms
+         terms = self.terms(queries)
+
+         outputs = []
+         for x, (name, answer) in enumerate(answers):
+             # Get matching topn
+             topn, reference = topns[x], None
+
+             if topn:
+                 # Build query from keyword terms and the answer text
+                 query = f"{terms[x]} {answers[x][1]}"
+
+                 # Compare answer to topns to find best match
+                 scores, _, _ = self.score([query], [text for _, text, _ in topn])
+
+                 # Get top score index
+                 index = scores[0][0][0]
+
+                 # Use matching topn id as reference
+                 reference = topn[index][0]
+
+             # Append (name, answer, reference) tuple
+             outputs.append((name, answer, reference))
+
+         return outputs
+
+     def terms(self, queries):
+         """
+         Extracts keyword terms from a list of queries using the underlying similarity model.
+
+         Args:
+             queries: list of queries
+
+         Returns:
+             list of queries reduced down to keyword term strings
+         """
+
+         # Extract keyword terms from queries if underlying similarity model supports it
+         return self.similarity.batchterms(queries) if hasattr(self.similarity, "batchterms") else queries
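
For orientation, here is a minimal usage sketch of the pipeline above. It is not part of the diff; the model paths are illustrative assumptions and the import paths follow the module layout in this package.

    # Minimal RAG sketch based on the signatures in this diff
    from txtai.embeddings import Embeddings
    from txtai.pipeline.llm.rag import RAG

    # batchsearch() above assumes content storage is enabled so that search
    # results are dictionaries with "id", "text" and "score" keys
    embeddings = Embeddings(path="sentence-transformers/all-MiniLM-L6-v2", content=True)
    embeddings.index(["US tops 5 million confirmed virus cases", "Maine man wins $1M from $25 lottery ticket"])

    # template must contain {question} and {context}; a plain string input is
    # expanded to a (None, query, question, None) tuple internally
    rag = RAG(embeddings, "google/flan-t5-small", template="Answer the question using the context.\nQuestion: {question}\nContext: {context}")

    # Per the apply() rules above, a single string input with the default
    # output format returns {"answer": ...}
    print(rag("How many confirmed virus cases are there?"))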
txtai/pipeline/nop.py ADDED
@@ -0,0 +1,14 @@
+ """
+ No-Op module
+ """
+
+ from .base import Pipeline
+
+
+ class Nop(Pipeline):
+     """
+     Simple no-op pipeline that returns inputs
+     """
+
+     def __call__(self, inputs):
+         return inputs
txtai/pipeline/tensors.py ADDED
@@ -0,0 +1,52 @@
+ """
+ Tensor processing framework module
+ """
+
+ import torch
+
+ from .base import Pipeline
+
+
+ class Tensors(Pipeline):
+     """
+     Pipeline backed by a tensor processing framework. Currently supports PyTorch.
+     """
+
+     def quantize(self, model):
+         """
+         Quantizes the input model and returns it. This is only supported for CPU devices.
+
+         Args:
+             model: torch model
+
+         Returns:
+             quantized torch model
+         """
+
+         # pylint: disable=E1101
+         return torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
+
+     def tensor(self, data):
+         """
+         Creates a tensor array.
+
+         Args:
+             data: input data
+
+         Returns:
+             tensor
+         """
+
+         # pylint: disable=E1102
+         return torch.tensor(data)
+
+     def context(self):
+         """
+         Defines a context used to wrap processing with the tensor processing framework.
+
+         Returns:
+             processing context
+         """
+
+         # pylint: disable=E1101
+         return torch.no_grad()
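
A short sketch (not part of the diff) of how these helpers compose, assuming the Pipeline base class can be instantiated without arguments:

    import torch

    from txtai.pipeline.tensors import Tensors

    pipeline = Tensors()

    # Dynamic int8 quantization of Linear layers, CPU inference only
    model = torch.nn.Linear(4, 2)
    quantized = pipeline.quantize(model)

    # Run inference inside the no-gradient context returned by context()
    with pipeline.context():
        output = quantized(pipeline.tensor([[1.0, 2.0, 3.0, 4.0]]))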
txtai/pipeline/text/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """
+ Text imports
+ """
+
+ from .crossencoder import CrossEncoder
+ from .entity import Entity
+ from .labels import Labels
+ from .lateencoder import LateEncoder
+ from .questions import Questions
+ from .reranker import Reranker
+ from .similarity import Similarity
+ from .summary import Summary
+ from .translation import Translation
txtai/pipeline/text/crossencoder.py ADDED
@@ -0,0 +1,70 @@
+ """
+ CrossEncoder module
+ """
+
+ import numpy as np
+
+ from ..hfpipeline import HFPipeline
+
+
+ class CrossEncoder(HFPipeline):
+     """
+     Computes similarity between query and list of text using a cross-encoder model
+     """
+
+     def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+         super().__init__("text-classification", path, quantize, gpu, model, **kwargs)
+
+     def __call__(self, query, texts, multilabel=True, workers=0):
+         """
+         Computes the similarity between query and list of text. Returns a list of
+         (id, score) sorted by highest score, where id is the index in texts.
+
+         This method supports query as a string or a list. If the input is a string,
+         the return type is a 1D list of (id, score). If query is a list, a 2D list
+         of (id, score) is returned with a row per string.
+
+         Args:
+             query: query text|list
+             texts: list of text
+             multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None
+             workers: number of concurrent workers to use for processing data, defaults to 0
+
+         Returns:
+             list of (id, score)
+         """
+
+         scores = []
+         for q in [query] if isinstance(query, str) else query:
+             # Pass (query, text) pairs to model
+             result = self.pipeline([{"text": q, "text_pair": t} for t in texts], top_k=None, function_to_apply="none", num_workers=workers)
+
+             # Apply score transform function
+             scores.append(self.function([r[0]["score"] for r in result], multilabel))
+
+         # Build list of (id, score) per query sorted by highest score
+         scores = [sorted(enumerate(row), key=lambda x: x[1], reverse=True) for row in scores]
+
+         return scores[0] if isinstance(query, str) else scores
+
+     def function(self, scores, multilabel):
+         """
+         Applies an output transformation function based on the value of multilabel.
+
+         Args:
+             scores: input scores
+             multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None
+
+         Returns:
+             transformed scores
+         """
+
+         # Output functions
+         # pylint: disable=C3001
+         identity = lambda x: x
+         sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
+         softmax = lambda x: np.exp(x) / np.sum(np.exp(x))
+         function = identity if multilabel is None else sigmoid if multilabel else softmax
+
+         # Apply output function
+         return function(np.array(scores))
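
A brief usage sketch (not part of the diff); the model path is an illustrative Hugging Face cross-encoder checkpoint:

    from txtai.pipeline.text import CrossEncoder

    encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    texts = ["Maine man wins $1M from $25 lottery ticket", "Make huge profits without work"]

    # Single query returns a 1D list of (id, score) sorted by score descending,
    # where id is the index into texts; multilabel=True applies a sigmoid per score
    print(encoder("feel good story", texts))

    # multilabel=None skips the sigmoid/softmax transform and returns raw logits
    print(encoder(["feel good story", "click bait"], texts, multilabel=None))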