mseep_txtai-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/scoring/tfidf.py ADDED
@@ -0,0 +1,358 @@
+ """
+ TFIDF module
+ """
+
+ import math
+ import os
+
+ from collections import Counter
+ from multiprocessing.pool import ThreadPool
+
+ import numpy as np
+
+ from ..pipeline import Tokenizer
+ from ..serialize import Serializer
+
+ from .base import Scoring
+ from .terms import Terms
+
+
+ class TFIDF(Scoring):
+     """
+     Term frequency-inverse document frequency (TF-IDF) scoring.
+     """
+
+     def __init__(self, config=None):
+         super().__init__(config)
+
+         # Document stats
+         self.total = 0
+         self.tokens = 0
+         self.avgdl = 0
+
+         # Word frequency
+         self.docfreq = Counter()
+         self.wordfreq = Counter()
+         self.avgfreq = 0
+
+         # IDF index
+         self.idf = {}
+         self.avgidf = 0
+
+         # Tag boosting
+         self.tags = Counter()
+
+         # Tokenizer, lazily loaded as needed
+         self.tokenizer = None
+
+         # Term index
+         self.terms = Terms(self.config["terms"], self.score, self.idf) if self.config.get("terms") else None
+
+         # Document data
+         self.documents = {} if self.config.get("content") else None
+
+         # Normalize scores
+         self.normalize = self.config.get("normalize")
+         self.avgscore = None
+
+     def insert(self, documents, index=None, checkpoint=None):
+         # Insert documents, calculate word frequency, total tokens and total documents
+         for uid, document, tags in documents:
+             # Extract text, if necessary
+             if isinstance(document, dict):
+                 document = document.get(self.text, document.get(self.object))
+
+             if document is not None:
+                 # If index is passed, use indexid, otherwise use id
+                 uid = index if index is not None else uid
+
+                 # Add entry to index if the data type is accepted
+                 if isinstance(document, (str, list)):
+                     # Store content
+                     if self.documents is not None:
+                         self.documents[uid] = document
+
+                     # Convert to tokens, if necessary
+                     tokens = self.tokenize(document) if isinstance(document, str) else document
+
+                     # Add tokens for id to term index
+                     if self.terms is not None:
+                         self.terms.insert(uid, tokens)
+
+                     # Add tokens and tags to stats
+                     self.addstats(tokens, tags)
+
+                 # Increment index
+                 index = index + 1 if index is not None else None
+
+     def delete(self, ids):
+         # Delete from terms index
+         if self.terms:
+             self.terms.delete(ids)
+
+         # Delete content
+         if self.documents:
+             for uid in ids:
+                 self.documents.pop(uid)
+
+     def index(self, documents=None):
+         # Call base method
+         super().index(documents)
+
+         # Build index if tokens parsed
+         if self.wordfreq:
+             # Calculate total token frequency
+             self.tokens = sum(self.wordfreq.values())
+
+             # Calculate average frequency per token
+             self.avgfreq = self.tokens / len(self.wordfreq.values())
+
+             # Calculate average document length in tokens
+             self.avgdl = self.tokens / self.total
+
+             # Compute IDF scores
+             idfs = self.computeidf(np.array(list(self.docfreq.values())))
+             for x, word in enumerate(self.docfreq):
+                 self.idf[word] = float(idfs[x])
+
+             # Average IDF score per token
+             self.avgidf = float(np.mean(idfs))
+
+             # Calculate average score across index
+             self.avgscore = self.score(self.avgfreq, self.avgidf, self.avgdl)
+
+             # Filter for tags that appear in at least 0.5% of the documents
+             self.tags = Counter({tag: number for tag, number in self.tags.items() if number >= self.total * 0.005})
+
+         # Index terms, if available
+         if self.terms:
+             self.terms.index()
+
+     def weights(self, tokens):
+         # Document length
+         length = len(tokens)
+
+         # Calculate token counts
+         freq = self.computefreq(tokens)
+         freq = np.array([freq[token] for token in tokens])
+
+         # Get idf scores
+         idf = np.array([self.idf[token] if token in self.idf else self.avgidf for token in tokens])
+
+         # Calculate score for each token, use as weight
+         weights = self.score(freq, idf, length).tolist()
+
+         # Boost weights of tag tokens to match the largest weight in the list
+         if self.tags:
+             tags = {token: self.tags[token] for token in tokens if token in self.tags}
+             if tags:
+                 maxWeight = max(weights)
+                 maxTag = max(tags.values())
+
+                 weights = [max(maxWeight * (tags[tokens[x]] / maxTag), weight) if tokens[x] in tags else weight for x, weight in enumerate(weights)]
+
+         return weights
+
+     def search(self, query, limit=3):
+         # Check if a term index is available
+         if self.terms:
+             # Parse query into terms
+             query = self.tokenize(query) if isinstance(query, str) else query
+
+             # Get top-n term query matches
+             scores = self.terms.search(query, limit)
+
+             # Normalize scores, if enabled
+             if self.normalize and scores:
+                 # Calculate max score = best score for this query + average index score
+                 # Limit max to 6 * average index score
+                 maxscore = min(scores[0][1] + self.avgscore, 6 * self.avgscore)
+
+                 # Normalize scores between 0 - 1 using maxscore
+                 scores = [(x, min(score / maxscore, 1.0)) for x, score in scores]
+
+             # Add content, if available
+             return self.results(scores)
+
+         return None
+
+     def batchsearch(self, queries, limit=3, threads=True):
+         # Calculate number of threads using a thread per 25k records in index
+         threads = math.ceil(self.count() / 25000) if isinstance(threads, bool) and threads else int(threads)
+         threads = min(max(threads, 1), os.cpu_count())
+
+         # This method can run across multiple threads since a number of the underlying regex and numpy calls release the GIL.
+         results = []
+         with ThreadPool(threads) as pool:
+             for result in pool.starmap(self.search, [(x, limit) for x in queries]):
+                 results.append(result)
+
+         return results
+
+     def count(self):
+         return self.terms.count() if self.terms else self.total
+
+     def load(self, path):
+         # Load scoring
+         state = Serializer.load(path)
+
+         # Convert to Counter instances
+         for key in ["docfreq", "wordfreq", "tags"]:
+             state[key] = Counter(state[key])
+
+         # Convert documents to dict
+         state["documents"] = dict(state["documents"]) if state["documents"] else state["documents"]
+
+         # Set parameters on this object
+         self.__dict__.update(state)
+
+         # Load terms
+         if self.config.get("terms"):
+             self.terms = Terms(self.config["terms"], self.score, self.idf)
+             self.terms.load(path + ".terms")
+
+     def save(self, path):
+         # Don't serialize the following fields
+         skipfields = ("config", "terms", "tokenizer")
+
+         # Get object state
+         state = {key: value for key, value in self.__dict__.items() if key not in skipfields}
+
+         # Update documents to tuples
+         state["documents"] = list(state["documents"].items()) if state["documents"] else state["documents"]
+
+         # Save scoring
+         Serializer.save(state, path)
+
+         # Save terms
+         if self.terms:
+             self.terms.save(path + ".terms")
+
+     def close(self):
+         if self.terms:
+             self.terms.close()
+
+     def issparse(self):
+         return self.terms is not None
+
+     def isnormalized(self):
+         return self.normalize
+
+     def computefreq(self, tokens):
+         """
+         Computes token frequency. Used for token weighting.
+
+         Args:
+             tokens: input tokens
+
+         Returns:
+             {token: count}
+         """
+
+         return Counter(tokens)
+
+     def computeidf(self, freq):
+         """
+         Computes an idf score for word frequency.
+
+         Args:
+             freq: word frequency
+
+         Returns:
+             idf score
+         """
+
+         return np.log((self.total + 1) / (freq + 1)) + 1
+
+     # pylint: disable=W0613
+     def score(self, freq, idf, length):
+         """
+         Calculates a score for each token.
+
+         Args:
+             freq: token frequency
+             idf: token idf score
+             length: total number of tokens in source document
+
+         Returns:
+             token score
+         """
+
+         return idf * np.sqrt(freq) * (1 / np.sqrt(length))
+
+     def addstats(self, tokens, tags):
+         """
+         Add tokens and tags to stats.
+
+         Args:
+             tokens: list of tokens
+             tags: list of tags
+         """
+
+         # Total number of times token appears, count all tokens
+         self.wordfreq.update(tokens)
+
+         # Total number of documents a token is in, count unique tokens
+         self.docfreq.update(set(tokens))
+
+         # Get list of unique tags
+         if tags:
+             self.tags.update(tags.split())
+
+         # Total document count
+         self.total += 1
+
+     def tokenize(self, text):
+         """
+         Tokenizes text using default tokenizer.
+
+         Args:
+             text: input text
+
+         Returns:
+             tokens
+         """
+
+         # Load tokenizer
+         if not self.tokenizer:
+             self.tokenizer = self.loadtokenizer()
+
+         return self.tokenizer(text)
+
+     def loadtokenizer(self):
+         """
+         Load default tokenizer.
+
+         Returns:
+             tokenize method
+         """
+
+         # Custom tokenizer settings
+         if self.config.get("tokenizer"):
+             return Tokenizer(**self.config.get("tokenizer"))
+
+         # A terms index uses a standard tokenizer
+         if self.config.get("terms"):
+             return Tokenizer()
+
+         # Standard scoring index without a terms index uses the backwards-compatible static tokenize method
+         return Tokenizer.tokenize
+
+     def results(self, scores):
+         """
+         Resolves a list of (id, score) with document content, if available. Otherwise, the original input is returned.
+
+         Args:
+             scores: list of (id, score)
+
+         Returns:
+             resolved results
+         """
+
+         # Convert to Python values
+         scores = [(x, float(score)) for x, score in scores]
+
+         if self.documents:
+             return [{"id": x, "text": self.documents[x], "score": score} for x, score in scores]
+
+         return scores
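
For reference, the scoring above pairs a smoothed IDF with square-root term-frequency dampening and document-length normalization. A minimal standalone sketch of that math, using plain numpy and made-up example counts:

import numpy as np

total = 1000                    # total documents in the index
docfreq = np.array([5, 400])    # documents containing a rare and a common term

# Smoothed IDF: log((total + 1) / (freq + 1)) + 1
idf = np.log((total + 1) / (docfreq + 1)) + 1   # approximately [6.12, 1.91]

# Score each term appearing twice in a 25-token document
freq, length = 2, 25
weights = idf * np.sqrt(freq) * (1 / np.sqrt(length))

Rare terms dominate the weights, while very common terms bottom out near 1 before the length normalization is applied.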
txtai/serialize/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ Serialize imports
+ """
+
+ from .base import Serialize
+ from .errors import SerializeError
+ from .factory import SerializeFactory
+ from .messagepack import MessagePack
+ from .pickle import Pickle
+ from .serializer import Serializer
txtai/serialize/base.py ADDED
@@ -0,0 +1,85 @@
+ """
+ Serialize module
+ """
+
+
+ class Serialize:
+     """
+     Base class for Serialize instances. This class serializes data to files, streams and bytes.
+     """
+
+     def load(self, path):
+         """
+         Loads data from path.
+
+         Args:
+             path: input path
+
+         Returns:
+             deserialized data
+         """
+
+         with open(path, "rb") as handle:
+             return self.loadstream(handle)
+
+     def save(self, data, path):
+         """
+         Saves data to path.
+
+         Args:
+             data: data to save
+             path: output path
+         """
+
+         with open(path, "wb") as handle:
+             self.savestream(data, handle)
+
+     def loadstream(self, stream):
+         """
+         Loads data from stream.
+
+         Args:
+             stream: input stream
+
+         Returns:
+             deserialized data
+         """
+
+         raise NotImplementedError
+
+     def savestream(self, data, stream):
+         """
+         Saves data to stream.
+
+         Args:
+             data: data to save
+             stream: output stream
+         """
+
+         raise NotImplementedError
+
+     def loadbytes(self, data):
+         """
+         Loads data from bytes.
+
+         Args:
+             data: input bytes
+
+         Returns:
+             deserialized data
+         """
+
+         raise NotImplementedError
+
+     def savebytes(self, data):
+         """
+         Saves data as bytes.
+
+         Args:
+             data: data to save
+
+         Returns:
+             serialized data
+         """
+
+         raise NotImplementedError
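
To make the contract concrete, here is a hypothetical subclass (not part of this package) that backs the four hooks with JSON. The stream methods work against the binary handles opened by load/save above:

import json

from txtai.serialize import Serialize

class JSONSerialize(Serialize):
    # Hypothetical example, not shipped with txtai
    def loadstream(self, stream):
        # json.load accepts the binary handle opened by Serialize.load
        return json.load(stream)

    def savestream(self, data, stream):
        # Serialize.save opens files in binary mode, so encode explicitly
        stream.write(json.dumps(data).encode("utf-8"))

    def loadbytes(self, data):
        return json.loads(data)

    def savebytes(self, data):
        return json.dumps(data).encode("utf-8")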
txtai/serialize/errors.py ADDED
@@ -0,0 +1,9 @@
+ """
+ Errors module
+ """
+
+
+ class SerializeError(Exception):
+     """
+     Raised when data serialization fails
+     """
txtai/serialize/factory.py ADDED
@@ -0,0 +1,29 @@
+ """
+ Factory module
+ """
+
+ from .messagepack import MessagePack
+ from .pickle import Pickle
+
+
+ class SerializeFactory:
+     """
+     Methods to create data serializers.
+     """
+
+     @staticmethod
+     def create(method=None, **kwargs):
+         """
+         Creates a new Serialize instance.
+
+         Args:
+             method: serialization method
+             kwargs: additional keyword arguments to pass to serialize instance
+         """
+
+         # Pickle serialization
+         if method == "pickle":
+             return Pickle(**kwargs)
+
+         # Default serialization
+         return MessagePack(**kwargs)
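
A sketch of typical factory usage, round-tripping a dict through the default serializer (the file path below is illustrative):

from txtai.serialize import SerializeFactory

data = {"total": 3, "terms": ["a", "b"]}

# Default method is MessagePack
serializer = SerializeFactory.create()
serializer.save(data, "/tmp/state")
assert serializer.load("/tmp/state") == data

# Explicit pickle serializer; kwargs pass through to the constructor
pickler = SerializeFactory.create("pickle", allowpickle=True)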
txtai/serialize/messagepack.py ADDED
@@ -0,0 +1,42 @@
+ """
+ MessagePack module
+ """
+
+ import msgpack
+ from msgpack import Unpacker
+ from msgpack.exceptions import ExtraData
+
+ from .base import Serialize
+ from .errors import SerializeError
+
+
+ class MessagePack(Serialize):
+     """
+     MessagePack serialization.
+     """
+
+     def __init__(self, streaming=False, **kwargs):
+         # Parent constructor
+         super().__init__()
+
+         # Streaming unpacker
+         self.streaming = streaming
+
+         # Additional streaming unpacker keyword arguments
+         self.kwargs = kwargs
+
+     def loadstream(self, stream):
+         try:
+             # Support both streaming and non-streaming unpacking of data
+             return Unpacker(stream, **self.kwargs) if self.streaming else msgpack.unpack(stream)
+         except ExtraData as e:
+             raise SerializeError(e) from e
+
+     def savestream(self, data, stream):
+         msgpack.pack(data, stream)
+
+     def loadbytes(self, data):
+         return msgpack.unpackb(data)
+
+     def savebytes(self, data):
+         return msgpack.packb(data)
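
One behavior worth noting in a sketch: msgpack round-trips tuples as lists, which is presumably why callers such as TFIDF.load above rebuild Counter and dict instances after loading:

from txtai.serialize import MessagePack

serializer = MessagePack()

# Bytes round trip
payload = serializer.savebytes({"total": 3})
assert serializer.loadbytes(payload) == {"total": 3}

# Tuples come back as lists
assert serializer.loadbytes(serializer.savebytes(("a", 1))) == ["a", 1]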
txtai/serialize/pickle.py ADDED
@@ -0,0 +1,98 @@
+ """
+ Pickle module
+ """
+
+ import os
+ import logging
+ import pickle
+ import warnings
+
+ from .base import Serialize
+
+ # Logging configuration
+ logger = logging.getLogger(__name__)
+
+
+ class Pickle(Serialize):
+     """
+     Pickle serialization.
+     """
+
+     def __init__(self, allowpickle=False):
+         """
+         Creates a new instance for Pickle serialization.
+
+         This class ensures the allowpickle parameter or the `ALLOW_PICKLE` environment variable is True. All methods will
+         raise errors if this isn't the case.
+
+         Pickle serialization is OK for local data but it isn't recommended when sharing data externally.
+
+         Args:
+             allowpickle: default pickle allow mode, only True with methods that generate local temporary data
+         """
+
+         # Parent constructor
+         super().__init__()
+
+         # Default allow pickle mode
+         self.allowpickle = allowpickle
+
+         # Current pickle protocol
+         self.version = 4
+
+     def load(self, path):
+         # Load pickled data from path, if allowed
+         return super().load(path) if self.allow(path) else None
+
+     def save(self, data, path):
+         # Save pickled data to path, if allowed
+         if self.allow():
+             super().save(data, path)
+
+     def loadstream(self, stream):
+         # Load pickled data from stream, if allowed
+         return pickle.load(stream) if self.allow() else None
+
+     def savestream(self, data, stream):
+         # Save pickled data to stream, if allowed
+         if self.allow():
+             pickle.dump(data, stream, protocol=self.version)
+
+     def loadbytes(self, data):
+         # Load pickled data from bytes, if allowed
+         return pickle.loads(data) if self.allow() else None
+
+     def savebytes(self, data):
+         # Save pickled data as bytes, if allowed
+         return pickle.dumps(data, protocol=self.version) if self.allow() else None
+
+     def allow(self, path=None):
+         """
+         Checks if loading and saving pickled data is allowed. Raises an error if it's not allowed.
+
+         Args:
+             path: optional path to add to generated error messages
+         """
+
+         enablepickle = self.allowpickle or os.environ.get("ALLOW_PICKLE", "False") in ("True", "1")
+         if not enablepickle:
+             raise ValueError(
+                 (
+                     "Loading of pickled index data is disabled. "
+                     f"`{path if path else 'stream'}` was not loaded. "
+                     "Set the env variable `ALLOW_PICKLE=True` to enable loading pickled index data. "
+                     "This should only be done for trusted and/or local data."
+                 )
+             )
+
+         if not self.allowpickle:
+             warnings.warn(
+                 (
+                     "Loading of pickled data enabled through `ALLOW_PICKLE=True` env variable. "
+                     "This setting should only be used with trusted and/or local data. "
+                     "Saving this index will replace pickled index data formats with the latest index formats and remove this warning."
+                 ),
+                 RuntimeWarning,
+             )
+
+         return enablepickle
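
A sketch of the gating in practice: allow() runs before any data is touched, so loads fail fast when pickling isn't enabled, and the environment variable is re-checked on every call:

import os

from txtai.serialize import Pickle

serializer = Pickle()

try:
    serializer.loadbytes(b"untrusted")   # raises ValueError before unpickling
except ValueError:
    pass

# Opt in (trusted and/or local data only)
os.environ["ALLOW_PICKLE"] = "True"
payload = serializer.savebytes({"total": 3})
assert serializer.loadbytes(payload) == {"total": 3}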
txtai/serialize/serializer.py ADDED
@@ -0,0 +1,46 @@
+ """
+ Serializer module
+ """
+
+ from .errors import SerializeError
+ from .factory import SerializeFactory
+
+
+ class Serializer:
+     """
+     Methods to serialize and deserialize data.
+     """
+
+     @staticmethod
+     def load(path):
+         """
+         Loads data from path. This method first tries to load the default serialization format.
+         If that fails, it falls back to the pickle format for backwards-compatibility purposes.
+
+         Note that loading pickle files requires the env variable `ALLOW_PICKLE=True`.
+
+         Args:
+             path: data to load
+
+         Returns:
+             data
+         """
+
+         try:
+             return SerializeFactory.create().load(path)
+         except SerializeError:
+             # Backwards compatible check for pickled data
+             return SerializeFactory.create("pickle").load(path)
+
+     @staticmethod
+     def save(data, path):
+         """
+         Saves data to path.
+
+         Args:
+             data: data to save
+             path: output path
+         """
+
+         # Save using default serialization method
+         SerializeFactory.create().save(data, path)
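
The load-time fallback means indexes saved by older pickle-based txtai versions still open (given `ALLOW_PICKLE=True`), while new saves always use MessagePack. A sketch with an illustrative path:

from txtai.serialize import Serializer

state = {"total": 100, "avgdl": 42.5}

# Always saves with the default MessagePack format
Serializer.save(state, "/tmp/scoring")

# Tries MessagePack first, falls back to pickle on SerializeError
assert Serializer.load("/tmp/scoring") == state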
txtai/util/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """
+ Utility imports
+ """
+
+ from .resolver import Resolver
+ from .sparsearray import SparseArray
+ from .template import TemplateFormatter
+ from .template import TemplateFormatter