mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/scoring/terms.py
ADDED
@@ -0,0 +1,499 @@
"""
Terms module
"""

import functools
import os
import sqlite3
import sys

from array import array
from collections import Counter
from threading import RLock

import numpy as np


class Terms:
    """
    Builds, searches and stores memory efficient term frequency sparse arrays for a scoring instance.
    """

    # Term frequency sparse arrays
    CREATE_TERMS = """
        CREATE TABLE IF NOT EXISTS terms (
            term TEXT PRIMARY KEY,
            ids BLOB,
            freqs BLOB
        )
    """

    INSERT_TERM = "INSERT OR REPLACE INTO terms VALUES (?, ?, ?)"
    SELECT_TERMS = "SELECT ids, freqs FROM terms WHERE term = ?"

    # Documents table
    CREATE_DOCUMENTS = """
        CREATE TABLE IF NOT EXISTS documents (
            indexid INTEGER PRIMARY KEY,
            id TEXT,
            deleted INTEGER,
            length INTEGER
        )
    """

    DELETE_DOCUMENTS = "DELETE FROM documents"
    INSERT_DOCUMENT = "INSERT OR REPLACE INTO documents VALUES (?, ?, ?, ?)"
    SELECT_DOCUMENTS = "SELECT indexid, id, deleted, length FROM documents ORDER BY indexid"

    def __init__(self, config, score, idf):
        """
        Creates a new terms index.

        Args:
            config: configuration
            score: score function
            idf: idf weights
        """

        # Terms index configuration
        self.config = config if isinstance(config, dict) else {}
        self.cachelimit = self.config.get("cachelimit", 250000000)
        self.cutoff = self.config.get("cutoff", 0.1)

        # Scoring function
        self.score, self.idf = score, idf

        # Document attributes
        self.ids, self.deletes, self.lengths = [], [], array("q")

        # Terms cache
        self.terms, self.cachesize = {}, 0

        # Terms database
        self.connection, self.cursor, self.path = None, None, None

        # Database thread lock
        self.lock = RLock()

    def insert(self, uid, terms):
        """
        Insert term into index.

        Args:
            uid: document id
            terms: document terms
        """

        # Initialize database, if necessary
        self.initialize()

        # Get next internal index id
        indexid = len(self.ids)

        # Calculate term frequency and document length
        freqs, length = Counter(terms), len(terms)

        # Add document terms
        for term, count in freqs.items():
            # Add term entry
            self.add(indexid, term, count)

            # Each id and freq is an 8-byte signed long long
            self.cachesize += 16

        # Flush cached terms to the database
        if self.cachesize >= self.cachelimit:
            self.index()

        # Save id and length
        self.ids.append(uid)
        self.lengths.append(length)

    def delete(self, ids):
        """
        Mark ids as deleted. This prevents deleted results from showing up in search results.
        The data is not removed from the underlying term frequency sparse arrays.

        Args:
            ids: ids to delete
        """

        # Set index ids as deleted
        self.deletes.extend([self.ids.index(i) for i in ids])

    def index(self):
        """
        Saves any remaining cached terms to the database.
        """

        for term, (nuids, nfreqs) in self.terms.items():
            # Retrieve existing uids/freqs
            uids, freqs = self.lookup(term)

            if uids:
                uids.extend(nuids)
                freqs.extend(nfreqs)
            else:
                uids, freqs = nuids, nfreqs

            # Always save as little endian
            if sys.byteorder == "big":
                uids.byteswap()
                freqs.byteswap()

            # Insert or replace term
            self.cursor.execute(Terms.INSERT_TERM, [term, uids.tobytes(), freqs.tobytes()])

        # Clear cached weights
        self.weights.cache_clear()

        # Reset term cache size
        self.terms, self.cachesize = {}, 0

    def search(self, terms, limit):
        """
        Searches term index a term-at-a-time. Each term frequency sparse array is retrieved
        and used to calculate term match scores.

        This method calculates term scores in two steps as shown below.

          1. Query and score less common term scores first
          2. Merge in common term scores for all documents matching the first query

        This is similar to the common terms query in Apache Lucene.

        Args:
            terms: query terms
            limit: maximum results

        Returns:
            list of (id, score)
        """

        # Initialize scores array
        scores = np.zeros(len(self.ids), dtype=np.float32)

        # Score less common terms
        terms, skipped, hasscores = Counter(terms), {}, False
        for term, freq in terms.items():
            # Compute or lookup term weights
            uids, weights = self.weights(term)
            if uids is not None:
                # Term considered common if it appears in more than 10% of index
                if len(uids) <= self.cutoff * len(self.ids):
                    # Add scores
                    scores[uids] += freq * weights

                    # Set flag that scores have been calculated for at least one term
                    hasscores = True
                else:
                    skipped[term] = freq

        # Merge in common term scores and return top n matches
        return self.topn(scores, limit, hasscores, skipped)

    def count(self):
        """
        Number of elements in the scoring index.

        Returns:
            count
        """

        return len(self.ids) - len(self.deletes)

    def load(self, path):
        """
        Loads terms database from path. This method loads document attributes into memory.

        Args:
            path: path to read terms database
        """

        # Load an existing terms database
        self.connection = self.connect(path)
        self.cursor = self.connection.cursor()
        self.path = path

        # Load document attributes
        self.ids, self.deletes, self.lengths = [], [], array("q")

        self.cursor.execute(Terms.SELECT_DOCUMENTS)
        for indexid, uid, deleted, length in self.cursor:
            # Index id - id
            self.ids.append(uid)

            # Deleted flag
            if deleted:
                self.deletes.append(indexid)

            # Index id - length
            self.lengths.append(length)

        # Cast ids to int if every id is an integer
        if all(uid.isdigit() for uid in self.ids):
            self.ids = [int(uid) for uid in self.ids]

        # Clear cache
        self.weights.cache_clear()

    def save(self, path):
        """
        Saves terms database to path. This method creates or replaces document attributes in the database.

        Args:
            path: path to write terms database
        """

        # Clear documents table
        self.cursor.execute(Terms.DELETE_DOCUMENTS)

        # Save document attributes
        for i, uid in enumerate(self.ids):
            self.cursor.execute(Terms.INSERT_DOCUMENT, [i, uid, 1 if i in self.deletes else 0, self.lengths[i]])

        # Temporary database
        if not self.path:
            # Save temporary database
            self.connection.commit()

            # Copy data from current to new
            connection = self.copy(path)

            # Close temporary database
            self.connection.close()

            # Point connection to new connection
            self.connection = connection
            self.cursor = self.connection.cursor()
            self.path = path

        # Paths are equal, commit changes
        elif self.path == path:
            self.connection.commit()

        # New path is different from current path, copy data and continue using current connection
        else:
            self.copy(path).close()

    def close(self):
        """
        Close and free resources used by this instance.
        """

        # Close connection
        if self.connection:
            self.connection.close()

    def initialize(self):
        """
        Creates connection and initial database schema if no connection exists.
        """

        if not self.connection:
            # Create term database
            self.connection = self.connect()
            self.cursor = self.connection.cursor()

            # Create initial schema
            self.cursor.execute(Terms.CREATE_TERMS)
            self.cursor.execute(Terms.CREATE_DOCUMENTS)

    def connect(self, path=""):
        """
        Creates a new term database connection.

        Args:
            path: path to term database file

        Returns:
            connection
        """

        connection = sqlite3.connect(path, check_same_thread=False)

        # Enable WAL mode, if necessary
        if self.config.get("wal"):
            connection.execute("PRAGMA journal_mode=WAL")

        return connection

    def copy(self, path):
        """
        Copies content from current terms database into target.

        Args:
            path: target database path

        Returns:
            new database connection
        """

        # Delete existing file, if necessary
        if os.path.exists(path):
            os.remove(path)

        # Create new connection
        connection = self.connect(path)

        if self.connection.in_transaction:
            # The backup call will hang if there are uncommitted changes, need to copy over
            # with iterdump (which is much slower)
            for sql in self.connection.iterdump():
                connection.execute(sql)
        else:
            # Database is up to date, can do a more efficient copy with SQLite C API
            self.connection.backup(connection)

        return connection

    def add(self, indexid, term, freq):
        """
        Adds a term frequency entry.

        Args:
            indexid: internal index id
            term: term
            freq: term frequency
        """

        # Get or create uids and freqs arrays
        if term not in self.terms:
            self.terms[term] = (array("q"), array("q"))

        # Append uids and freqs
        ids, freqs = self.terms[term]
        ids.append(indexid)
        freqs.append(freq)

    def lookup(self, term):
        """
        Retrieves a term frequency sparse array.

        Args:
            term: term to lookup

        Returns:
            term frequency sparse array
        """

        uids, freqs = None, None

        result = self.cursor.execute(Terms.SELECT_TERMS, [term]).fetchone()
        if result:
            uids, freqs = (array("q"), array("q"))
            uids.frombytes(result[0])
            freqs.frombytes(result[1])

            # Storage format is always little endian
            if sys.byteorder == "big":
                uids.byteswap()
                freqs.byteswap()

        return uids, freqs

    @functools.lru_cache(maxsize=500)
    def weights(self, term):
        """
        Computes a term weights sparse array for term. This method is wrapped with a least recently used cache,
        which will return common term weights from the cache.

        Args:
            term: term

        Returns:
            term weights sparse array
        """

        lengths = np.frombuffer(self.lengths, dtype=np.int64)

        with self.lock:
            uids, freqs = self.lookup(term)
            weights = None

            if uids:
                uids = np.frombuffer(uids, dtype=np.int64)
                weights = self.score(np.frombuffer(freqs, dtype=np.int64), self.idf[term], lengths[uids]).astype(np.float32)

        return uids, weights

    def topn(self, scores, limit, hasscores, skipped):
        """
        Get topn scores from a partial scores array.

        Args:
            scores: partial scores array with scores for less common terms
            limit: maximum results
            hasscores: True if partial scores array has any nonzero scores, False otherwise
            skipped: terms skipped in initial query

        Returns:
            topn scores
        """

        # Calculate topn candidates to consider
        # Require at least one positive score, set topn to smaller of limit * 5 or number of scores
        topn = min(len(scores), limit * 5)

        # Get topn candidates, allows for score shifting when adding in common term scores
        matches = self.candidates(scores, topn)

        # Merge in scores for more common terms
        self.merge(scores, matches, hasscores, skipped)

        # Get topn candidates since it was initially skipped above
        if not hasscores:
            matches = self.candidates(scores, topn)

        # Reorder matches using updated scores
        matches = matches[np.argsort(-scores[matches])]

        # Combine ids with scores. Require score > 0.
        return [(self.ids[x], float(scores[x])) for x in matches[:limit] if scores[x] > 0]

    def merge(self, scores, matches, hasscores, terms):
        """
        Merges common term scores into scores array.

        Args:
            scores: partial scores array
            matches: current matches, if any
            hasscores: True if scores has current matches, False otherwise
            terms: common terms
        """

        for term, freq in terms.items():
            # Compute or lookup term weights
            uids, weights = self.weights(term)

            # Filter to topn matches when partial scores array has nonzero scores
            if hasscores:
                # Find indices in match ids for uids
                indices = np.searchsorted(uids, matches)

                # Filter matches that don't exist in uids
                indices = [x for i, x in enumerate(indices) if x < len(uids) and uids[x] == matches[i]]

                # Filter to matching uids and weights
                uids, weights = uids[indices], weights[indices]

            # Update scores
            scores[uids] += freq * weights

    def candidates(self, scores, topn):
        """
        Gets the topn scored candidates. This method ignores deleted documents.

        Args:
            scores: scores array
            topn: topn elements

        Returns:
            topn scored candidates
        """

        # Clear deletes
        scores[self.deletes] = 0

        # Get topn candidates
        return np.argpartition(scores, -topn)[-topn:]