mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/scoring/tfidf.py
ADDED
@@ -0,0 +1,358 @@
"""
TFIDF module
"""

import math
import os

from collections import Counter
from multiprocessing.pool import ThreadPool

import numpy as np

from ..pipeline import Tokenizer
from ..serialize import Serializer

from .base import Scoring
from .terms import Terms


class TFIDF(Scoring):
    """
    Term frequency-inverse document frequency (TF-IDF) scoring.
    """

    def __init__(self, config=None):
        super().__init__(config)

        # Document stats
        self.total = 0
        self.tokens = 0
        self.avgdl = 0

        # Word frequency
        self.docfreq = Counter()
        self.wordfreq = Counter()
        self.avgfreq = 0

        # IDF index
        self.idf = {}
        self.avgidf = 0

        # Tag boosting
        self.tags = Counter()

        # Tokenizer, lazily loaded as needed
        self.tokenizer = None

        # Term index
        self.terms = Terms(self.config["terms"], self.score, self.idf) if self.config.get("terms") else None

        # Document data
        self.documents = {} if self.config.get("content") else None

        # Normalize scores
        self.normalize = self.config.get("normalize")
        self.avgscore = None

    def insert(self, documents, index=None, checkpoint=None):
        # Insert documents, calculate word frequency, total tokens and total documents
        for uid, document, tags in documents:
            # Extract text, if necessary
            if isinstance(document, dict):
                document = document.get(self.text, document.get(self.object))

            if document is not None:
                # If index is passed, use indexid, otherwise use id
                uid = index if index is not None else uid

                # Add entry to index if the data type is accepted
                if isinstance(document, (str, list)):
                    # Store content
                    if self.documents is not None:
                        self.documents[uid] = document

                    # Convert to tokens, if necessary
                    tokens = self.tokenize(document) if isinstance(document, str) else document

                    # Add tokens for id to term index
                    if self.terms is not None:
                        self.terms.insert(uid, tokens)

                    # Add tokens and tags to stats
                    self.addstats(tokens, tags)

                # Increment index
                index = index + 1 if index is not None else None

    def delete(self, ids):
        # Delete from terms index
        if self.terms:
            self.terms.delete(ids)

        # Delete content
        if self.documents:
            for uid in ids:
                self.documents.pop(uid)

    def index(self, documents=None):
        # Call base method
        super().index(documents)

        # Build index if tokens parsed
        if self.wordfreq:
            # Calculate total token frequency
            self.tokens = sum(self.wordfreq.values())

            # Calculate average frequency per token
            self.avgfreq = self.tokens / len(self.wordfreq.values())

            # Calculate average document length in tokens
            self.avgdl = self.tokens / self.total

            # Compute IDF scores
            idfs = self.computeidf(np.array(list(self.docfreq.values())))
            for x, word in enumerate(self.docfreq):
                self.idf[word] = float(idfs[x])

            # Average IDF score per token
            self.avgidf = float(np.mean(idfs))

            # Calculate average score across index
            self.avgscore = self.score(self.avgfreq, self.avgidf, self.avgdl)

            # Filter for tags that appear in at least 1% of the documents
            self.tags = Counter({tag: number for tag, number in self.tags.items() if number >= self.total * 0.005})

        # Index terms, if available
        if self.terms:
            self.terms.index()

    def weights(self, tokens):
        # Document length
        length = len(tokens)

        # Calculate token counts
        freq = self.computefreq(tokens)
        freq = np.array([freq[token] for token in tokens])

        # Get idf scores
        idf = np.array([self.idf[token] if token in self.idf else self.avgidf for token in tokens])

        # Calculate score for each token, use as weight
        weights = self.score(freq, idf, length).tolist()

        # Boost weights of tag tokens to match the largest weight in the list
        if self.tags:
            tags = {token: self.tags[token] for token in tokens if token in self.tags}
            if tags:
                maxWeight = max(weights)
                maxTag = max(tags.values())

                weights = [max(maxWeight * (tags[tokens[x]] / maxTag), weight) if tokens[x] in tags else weight for x, weight in enumerate(weights)]

        return weights

    def search(self, query, limit=3):
        # Check if term index available
        if self.terms:
            # Parse query into terms
            query = self.tokenize(query) if isinstance(query, str) else query

            # Get topn term query matches
            scores = self.terms.search(query, limit)

            # Normalize scores, if enabled
            if self.normalize and scores:
                # Calculate max score = best score for this query + average index score
                # Limit max to 6 * average index score
                maxscore = min(scores[0][1] + self.avgscore, 6 * self.avgscore)

                # Normalize scores between 0 - 1 using maxscore
                scores = [(x, min(score / maxscore, 1.0)) for x, score in scores]

            # Add content, if available
            return self.results(scores)

        return None

    def batchsearch(self, queries, limit=3, threads=True):
        # Calculate number of threads using a thread per 25k records in index
        threads = math.ceil(self.count() / 25000) if isinstance(threads, bool) and threads else int(threads)
        threads = min(max(threads, 1), os.cpu_count())

        # This method is able to run as multiple threads due to a number of regex and numpy method calls that drop the GIL.
        results = []
        with ThreadPool(threads) as pool:
            for result in pool.starmap(self.search, [(x, limit) for x in queries]):
                results.append(result)

        return results

    def count(self):
        return self.terms.count() if self.terms else self.total

    def load(self, path):
        # Load scoring
        state = Serializer.load(path)

        # Convert to Counter instances
        for key in ["docfreq", "wordfreq", "tags"]:
            state[key] = Counter(state[key])

        # Convert documents to dict
        state["documents"] = dict(state["documents"]) if state["documents"] else state["documents"]

        # Set parameters on this object
        self.__dict__.update(state)

        # Load terms
        if self.config.get("terms"):
            self.terms = Terms(self.config["terms"], self.score, self.idf)
            self.terms.load(path + ".terms")

    def save(self, path):
        # Don't serialize following fields
        skipfields = ("config", "terms", "tokenizer")

        # Get object state
        state = {key: value for key, value in self.__dict__.items() if key not in skipfields}

        # Update documents to tuples
        state["documents"] = list(state["documents"].items()) if state["documents"] else state["documents"]

        # Save scoring
        Serializer.save(state, path)

        # Save terms
        if self.terms:
            self.terms.save(path + ".terms")

    def close(self):
        if self.terms:
            self.terms.close()

    def issparse(self):
        return self.terms is not None

    def isnormalized(self):
        return self.normalize

    def computefreq(self, tokens):
        """
        Computes token frequency. Used for token weighting.

        Args:
            tokens: input tokens

        Returns:
            {token: count}
        """

        return Counter(tokens)

    def computeidf(self, freq):
        """
        Computes an idf score for word frequency.

        Args:
            freq: word frequency

        Returns:
            idf score
        """

        return np.log((self.total + 1) / (freq + 1)) + 1

    # pylint: disable=W0613
    def score(self, freq, idf, length):
        """
        Calculates a score for each token.

        Args:
            freq: token frequency
            idf: token idf score
            length: total number of tokens in source document

        Returns:
            token score
        """

        return idf * np.sqrt(freq) * (1 / np.sqrt(length))

    def addstats(self, tokens, tags):
        """
        Add tokens and tags to stats.

        Args:
            tokens: list of tokens
            tags: list of tags
        """

        # Total number of times token appears, count all tokens
        self.wordfreq.update(tokens)

        # Total number of documents a token is in, count unique tokens
        self.docfreq.update(set(tokens))

        # Get list of unique tags
        if tags:
            self.tags.update(tags.split())

        # Total document count
        self.total += 1

    def tokenize(self, text):
        """
        Tokenizes text using default tokenizer.

        Args:
            text: input text

        Returns:
            tokens
        """

        # Load tokenizer
        if not self.tokenizer:
            self.tokenizer = self.loadtokenizer()

        return self.tokenizer(text)

    def loadtokenizer(self):
        """
        Load default tokenizer.

        Returns:
            tokenize method
        """

        # Custom tokenizer settings
        if self.config.get("tokenizer"):
            return Tokenizer(**self.config.get("tokenizer"))

        # Terms index uses a standard tokenizer
        if self.config.get("terms"):
            return Tokenizer()

        # Standard scoring index without a terms index uses backwards compatible static tokenize method
        return Tokenizer.tokenize

    def results(self, scores):
        """
        Resolves a list of (id, score) with document content, if available. Otherwise, the original input is returned.

        Args:
            scores: list of (id, score)

        Returns:
            resolved results
        """

        # Convert to Python values
        scores = [(x, float(score)) for x, score in scores]

        if self.documents:
            return [{"id": x, "text": self.documents[x], "score": score} for x, score in scores]

        return scores
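Below is a minimal usage sketch of the TFIDF class above, assuming an installed txtai package. The config keys (`terms`, `content`), method signatures and result shape follow the code as shown; the document ids, text and query are purely illustrative.

from txtai.scoring.tfidf import TFIDF

# Build a TF-IDF index with a term index and stored content
scoring = TFIDF({"terms": True, "content": True})

# insert() expects (id, data, tags) tuples
scoring.insert([
    (0, "US tops 5 million confirmed virus cases", None),
    (1, "Maine man wins lottery from a lost ticket", None),
])

# Finalize statistics and the term index, then run a query
scoring.index()
print(scoring.search("virus cases", limit=1))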
txtai/serialize/base.py
ADDED
@@ -0,0 +1,85 @@
"""
Serialize module
"""


class Serialize:
    """
    Base class for Serialize instances. This class serializes data to files, streams and bytes.
    """

    def load(self, path):
        """
        Loads data from path.

        Args:
            path: input path

        Returns:
            deserialized data
        """

        with open(path, "rb") as handle:
            return self.loadstream(handle)

    def save(self, data, path):
        """
        Saves data to path.

        Args:
            data: data to save
            path: output path
        """

        with open(path, "wb") as handle:
            self.savestream(data, handle)

    def loadstream(self, stream):
        """
        Loads data from stream.

        Args:
            stream: input stream

        Returns:
            deserialized data
        """

        raise NotImplementedError

    def savestream(self, data, stream):
        """
        Saves data to stream.

        Args:
            data: data to save
            stream: output stream
        """

        raise NotImplementedError

    def loadbytes(self, data):
        """
        Loads data from bytes.

        Args:
            data: input bytes

        Returns:
            deserialized data
        """

        raise NotImplementedError

    def savebytes(self, data):
        """
        Saves data as bytes.

        Args:
            data: data to save

        Returns:
            serialized data
        """

        raise NotImplementedError
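The four abstract methods above define the full serialization contract. As an illustration only (not part of the package), a hypothetical JSON-backed subclass would implement just those methods; it assumes the Serialize base class shown above is importable.

import json

class JSONSerialize(Serialize):
    """Hypothetical JSON serializer built on the Serialize base class above."""

    def loadstream(self, stream):
        # Base class opens files in binary mode, so decode before parsing
        return json.loads(stream.read().decode("utf-8"))

    def savestream(self, data, stream):
        stream.write(json.dumps(data).encode("utf-8"))

    def loadbytes(self, data):
        return json.loads(data)

    def savebytes(self, data):
        return json.dumps(data).encode("utf-8")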
txtai/serialize/factory.py
ADDED
@@ -0,0 +1,29 @@
"""
Factory module
"""

from .messagepack import MessagePack
from .pickle import Pickle


class SerializeFactory:
    """
    Methods to create data serializers.
    """

    @staticmethod
    def create(method=None, **kwargs):
        """
        Creates a new Serialize instance.

        Args:
            method: serialization method
            kwargs: additional keyword arguments to pass to serialize instance
        """

        # Pickle serialization
        if method == "pickle":
            return Pickle(**kwargs)

        # Default serialization
        return MessagePack(**kwargs)
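A short sketch of how the factory above is expected to be called; the method names come from the code, everything else is illustrative.

from txtai.serialize.factory import SerializeFactory

# Default serializer (MessagePack)
default = SerializeFactory.create()

# Explicit pickle serializer, passing allowpickle through kwargs
pkl = SerializeFactory.create("pickle", allowpickle=True)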
txtai/serialize/messagepack.py
ADDED
@@ -0,0 +1,42 @@
"""
MessagePack module
"""

import msgpack
from msgpack import Unpacker
from msgpack.exceptions import ExtraData

from .base import Serialize
from .errors import SerializeError


class MessagePack(Serialize):
    """
    MessagePack serialization.
    """

    def __init__(self, streaming=False, **kwargs):
        # Parent constructor
        super().__init__()

        # Streaming unpacker
        self.streaming = streaming

        # Additional streaming unpacker keyword arguments
        self.kwargs = kwargs

    def loadstream(self, stream):
        try:
            # Support both streaming and non-streaming unpacking of data
            return Unpacker(stream, **self.kwargs) if self.streaming else msgpack.unpack(stream)
        except ExtraData as e:
            raise SerializeError(e) from e

    def savestream(self, data, stream):
        msgpack.pack(data, stream)

    def loadbytes(self, data):
        return msgpack.unpackb(data)

    def savebytes(self, data):
        return msgpack.packb(data)
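A usage sketch, assuming an installed txtai package with the msgpack dependency; file path and data are illustrative. Note that streaming mode returns a lazy Unpacker, so the caller must keep the stream open while iterating, which is why loadstream is called directly below instead of load.

from txtai.serialize.messagepack import MessagePack

serializer = MessagePack()
serializer.save({"id": 1, "text": "hello"}, "/tmp/data.msgpack")
print(serializer.load("/tmp/data.msgpack"))

# Streaming mode yields objects as they are unpacked
streaming = MessagePack(streaming=True)
with open("/tmp/data.msgpack", "rb") as handle:
    for item in streaming.loadstream(handle):
        print(item)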
txtai/serialize/pickle.py
ADDED
@@ -0,0 +1,98 @@
"""
Pickle module
"""

import os
import logging
import pickle
import warnings

from .base import Serialize

# Logging configuration
logger = logging.getLogger(__name__)


class Pickle(Serialize):
    """
    Pickle serialization.
    """

    def __init__(self, allowpickle=False):
        """
        Creates a new instance for Pickle serialization.

        This class ensures the allowpickle parameter or the `ALLOW_PICKLE` environment variable is True. All methods will
        raise errors if this isn't the case.

        Pickle serialization is OK for local data but it isn't recommended when sharing data externally.

        Args:
            allowpickle: default pickle allow mode, only True with methods that generate local temporary data
        """

        # Parent constructor
        super().__init__()

        # Default allow pickle mode
        self.allowpickle = allowpickle

        # Current pickle protocol
        self.version = 4

    def load(self, path):
        # Load pickled data from path, if allowed
        return super().load(path) if self.allow(path) else None

    def save(self, data, path):
        # Save pickled data to path, if allowed
        if self.allow():
            super().save(data, path)

    def loadstream(self, stream):
        # Load pickled data from stream, if allowed
        return pickle.load(stream) if self.allow() else None

    def savestream(self, data, stream):
        # Save pickled data to stream, if allowed
        if self.allow():
            pickle.dump(data, stream, protocol=self.version)

    def loadbytes(self, data):
        # Load pickled data from bytes, if allowed
        return pickle.loads(data) if self.allow() else None

    def savebytes(self, data):
        # Save pickled data to bytes, if allowed
        return pickle.dumps(data, protocol=self.version) if self.allow() else None

    def allow(self, path=None):
        """
        Checks if loading and saving pickled data is allowed. Raises an error if it's not allowed.

        Args:
            path: optional path to add to generated error messages
        """

        enablepickle = self.allowpickle or os.environ.get("ALLOW_PICKLE", "False") in ("True", "1")
        if not enablepickle:
            raise ValueError(
                (
                    "Loading of pickled index data is disabled. "
                    f"`{path if path else 'stream'}` was not loaded. "
                    "Set the env variable `ALLOW_PICKLE=True` to enable loading pickled index data. "
                    "This should only be done for trusted and/or local data."
                )
            )

        if not self.allowpickle:
            warnings.warn(
                (
                    "Loading of pickled data enabled through `ALLOW_PICKLE=True` env variable. "
                    "This setting should only be used with trusted and/or local data. "
                    "Saving this index will replace pickled index data formats with the latest index formats and remove this warning."
                ),
                RuntimeWarning,
            )

        return enablepickle
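A sketch of the guard above (data and environment handling illustrative, assuming ALLOW_PICKLE is not already set): pickling raises unless it is enabled per instance or via the environment variable.

import os
from txtai.serialize.pickle import Pickle

data = {"ids": [1, 2, 3]}

# Raises ValueError: pickling is disabled by default
try:
    Pickle().savebytes(data)
except ValueError as error:
    print(error)

# Enabled explicitly per instance
serializer = Pickle(allowpickle=True)
print(serializer.loadbytes(serializer.savebytes(data)))

# Or enabled globally via the environment (emits a RuntimeWarning when allowpickle is False)
os.environ["ALLOW_PICKLE"] = "True"
print(Pickle().loadbytes(Pickle().savebytes(data)))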
txtai/serialize/serializer.py
ADDED
@@ -0,0 +1,46 @@
"""
Serializer module
"""

from .errors import SerializeError
from .factory import SerializeFactory


class Serializer:
    """
    Methods to serialize and deserialize data.
    """

    @staticmethod
    def load(path):
        """
        Loads data from path. This method first tries to load the default serialization format.
        If that fails, it will fall back to pickle format for backwards-compatibility purposes.

        Note that loading pickle files requires the env variable `ALLOW_PICKLE=True`.

        Args:
            path: data to load

        Returns:
            data
        """

        try:
            return SerializeFactory.create().load(path)
        except SerializeError:
            # Backwards compatible check for pickled data
            return SerializeFactory.create("pickle").load(path)

    @staticmethod
    def save(data, path):
        """
        Saves data to path.

        Args:
            data: data to save
            path: output path
        """

        # Save using default serialization method
        SerializeFactory.create().save(data, path)