mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/scoring/factory.py
ADDED
@@ -0,0 +1,95 @@
|
|
1
|
+
"""
|
2
|
+
Factory module
|
3
|
+
"""
|
4
|
+
|
5
|
+
from ..util import Resolver
|
6
|
+
|
7
|
+
from .bm25 import BM25
|
8
|
+
from .pgtext import PGText
|
9
|
+
from .sif import SIF
|
10
|
+
from .sparse import Sparse
|
11
|
+
from .tfidf import TFIDF
|
12
|
+
|
13
|
+
|
14
|
+
class ScoringFactory:
    """
    Methods to create Scoring indexes.
    """

    @staticmethod
    def create(config, models=None):
        """
        Factory method to construct a Scoring instance.

        Args:
            config: scoring configuration parameters, either a method name string or a dict of settings
            models: models cache, only passed to the "sparse" scoring method

        Returns:
            Scoring

        Raises:
            ImportError: if the method is not a built-in and can't be resolved as a custom backend
        """

        # Support string and dict configuration
        if isinstance(config, str):
            config = {"method": config}

        # Get scoring method, defaults to bm25
        method = config.get("method", "bm25")

        if method == "bm25":
            scoring = BM25(config)
        elif method == "pgtext":
            scoring = PGText(config)
        elif method == "sif":
            scoring = SIF(config)
        elif method == "sparse":
            scoring = Sparse(config, models)
        elif method == "tfidf":
            scoring = TFIDF(config)
        else:
            # Resolve custom method
            scoring = ScoringFactory.resolve(method, config)

        # Store config back, this makes the default method explicit in the configuration
        config["method"] = method

        return scoring

    @staticmethod
    def issparse(config):
        """
        Checks if this scoring configuration builds a sparse index.

        Args:
            config: scoring configuration

        Returns:
            True if this config is for a sparse index, False otherwise
        """

        # Only non-empty dict configurations can describe a sparse index
        if not config or not isinstance(config, dict):
            return False

        # Types that are always a sparse index
        indexes = ("pgtext", "sparse")

        # Always return a strict boolean, a truthy "terms" setting also enables a sparse index
        return config.get("method") in indexes or bool(config.get("terms"))

    @staticmethod
    def resolve(backend, config):
        """
        Attempt to resolve a custom backend.

        Args:
            backend: backend class
            config: index configuration parameters

        Returns:
            Scoring

        Raises:
            ImportError: if the backend can't be resolved or instantiated
        """

        try:
            return Resolver()(backend)(config)
        except Exception as e:
            # Any failure is reported as an import problem, original error is chained as the cause
            raise ImportError(f"Unable to resolve scoring backend: '{backend}'") from e
|
txtai/scoring/pgtext.py
ADDED
@@ -0,0 +1,181 @@
|
|
1
|
+
"""
|
2
|
+
PGText module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
|
7
|
+
# Conditional import
|
8
|
+
try:
|
9
|
+
from sqlalchemy import create_engine, desc, delete, func, text
|
10
|
+
from sqlalchemy import Column, Computed, Index, Integer, MetaData, StaticPool, Table, Text
|
11
|
+
from sqlalchemy.dialects.postgresql import TSVECTOR
|
12
|
+
from sqlalchemy.orm import Session
|
13
|
+
from sqlalchemy.schema import CreateSchema
|
14
|
+
|
15
|
+
PGTEXT = True
|
16
|
+
except ImportError:
|
17
|
+
PGTEXT = False
|
18
|
+
|
19
|
+
from .base import Scoring
|
20
|
+
|
21
|
+
|
22
|
+
class PGText(Scoring):
    """
    Postgres full text search (FTS) based scoring.
    """

    def __init__(self, config=None):
        """
        Creates a new PGText instance. The database connection is opened lazily by initialize.

        Args:
            config: scoring configuration parameters

        Raises:
            ImportError: if the optional sqlalchemy dependencies are not installed
        """

        super().__init__(config)

        if not PGTEXT:
            raise ImportError('PGText is not available - install "scoring" extra to enable')

        # Database connection
        self.engine, self.database, self.connection, self.table = None, None, None, None

        # Language
        self.language = self.config.get("language", "english")

    def insert(self, documents, index=None, checkpoint=None):
        """
        Inserts documents into the scoring table.

        Args:
            documents: iterable of 3-element tuples (id, document, third element is ignored)
            index: starting index id, when set ids are assigned sequentially from this value instead of using the document id
            checkpoint: unused by this method
        """

        # Initialize tables. This only has an effect on the first call, initialize is a no-op once a session exists.
        self.initialize(recreate=True)

        # Collection of rows to insert
        rows = []

        # Collect rows
        for uid, document, _ in documents:
            # Extract text, if necessary
            if isinstance(document, dict):
                document = document.get(self.text, document.get(self.object))

            if document is not None:
                # If index is passed, use indexid, otherwise use id
                uid = index if index is not None else uid

                # Add row if the data type is accepted
                if isinstance(document, (str, list)):
                    rows.append({"indexid": uid, "text": " ".join(document) if isinstance(document, list) else document})

                # Increment index, also incremented for non-text data that isn't stored
                index = index + 1 if index is not None else None

        # Insert rows. Skip when empty, an empty parameter list is not a no-op and would attempt an insert without values.
        if rows:
            self.database.execute(self.table.insert(), rows)

    def delete(self, ids):
        """
        Deletes rows by index id.

        Args:
            ids: index ids to delete
        """

        self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))

    def weights(self, tokens):
        # Not supported
        return None

    def search(self, query, limit=3):
        """
        Runs a full text search ranked with ts_rank.

        Args:
            query: query text, parsed with plainto_tsquery
            limit: maximum number of rows to retrieve

        Returns:
            list of (index id, score), rows with a score at or below 1e-5 are dropped
        """

        # Run query
        results = (
            self.database.query(self.table.c["indexid"], text("ts_rank(vector, plainto_tsquery(:language, :query)) rank"))
            .order_by(desc(text("rank")))
            .limit(limit)
            .params({"language": self.language, "query": query})
        )

        return [(uid, score) for uid, score in results if score > 1e-5]

    def batchsearch(self, queries, limit=3, threads=True):
        # Queries are run sequentially, the threads parameter is not used
        return [self.search(query, limit) for query in queries]

    def count(self):
        # pylint: disable=E1102
        return self.database.query(func.count(self.table.c["indexid"])).scalar()

    def load(self, path):
        """
        Loads the index. Data is stored in the database, path is not used.

        Args:
            path: unused by this method
        """

        # Reset database to original checkpoint
        if self.database:
            self.database.rollback()
            self.connection.rollback()

        # Initialize tables
        self.initialize()

    def save(self, path):
        """
        Saves the index by committing pending database changes, path is not used.

        Args:
            path: unused by this method
        """

        # Commit session and connection
        if self.database:
            self.database.commit()
            self.connection.commit()

    def close(self):
        # Close session and dispose of engine, if a connection was opened
        if self.database:
            self.database.close()
            self.engine.dispose()

    def issparse(self):
        return True

    def isnormalized(self):
        return True

    def initialize(self, recreate=False):
        """
        Initializes a new database session. This method does nothing when a session already exists.

        Args:
            recreate: Recreates the database tables if True
        """

        if not self.database:
            # Create engine, connection and session. Url is read from config with a fallback to the SCORING_URL environment variable.
            self.engine = create_engine(self.config.get("url", os.environ.get("SCORING_URL")), poolclass=StaticPool, echo=False)
            self.connection = self.engine.connect()
            self.database = Session(self.connection)

            # Set default schema, if necessary
            schema = self.config.get("schema")
            if schema:
                with self.engine.begin():
                    self.sqldialect(CreateSchema(schema, if_not_exists=True))

                self.sqldialect(text("SET search_path TO :schema"), {"schema": schema})

            # Table name
            table = self.config.get("table", "scoring")

            # Create vectors table
            # NOTE(review): language is interpolated directly into the computed column DDL, it must come from trusted configuration
            self.table = Table(
                table,
                MetaData(),
                Column("indexid", Integer, primary_key=True, autoincrement=False),
                Column("text", Text),
                (
                    Column("vector", TSVECTOR, Computed(f"to_tsvector('{self.language}', text)", persisted=True))
                    if self.engine.dialect.name == "postgresql"
                    else Column("vector", Integer)
                ),
            )

            # Create text index
            index = Index(
                f"{table}-index",
                self.table.c["vector"],
                postgresql_using="gin",
            )

            # Drop and recreate table
            if recreate:
                self.table.drop(self.connection, checkfirst=True)
                index.drop(self.connection, checkfirst=True)

            # Create table and index
            self.table.create(self.connection, checkfirst=True)
            index.create(self.connection, checkfirst=True)

    def sqldialect(self, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect. For dialects other than
        postgresql, the statement is replaced with a placeholder "SELECT 1".

        Args:
            sql: SQL to execute
            parameters: optional bind parameters
        """

        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
        self.database.execute(*args)
|
txtai/scoring/sif.py
ADDED
@@ -0,0 +1,32 @@
|
|
1
|
+
"""
|
2
|
+
SIF module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from .tfidf import TFIDF
|
8
|
+
|
9
|
+
|
10
|
+
class SIF(TFIDF):
    """
    Smooth Inverse Frequency (SIF) scoring.
    """

    def __init__(self, config=None):
        super().__init__(config)

        # SIF smoothing parameter, read from the "a" configuration setting
        self.a = self.config.get("a", 1e-3)

    def computefreq(self, tokens):
        # The default method computes frequency for a single entry
        # SIF instead looks up each token's frequency across the entire index
        frequencies = {}
        for token in tokens:
            frequencies[token] = self.wordfreq[token]

        return frequencies

    def score(self, freq, idf, length):
        # When freq is an array with a different shape than idf, overwrite every element in place with the array sum
        mismatch = isinstance(freq, np.ndarray) and freq.shape != np.array(idf).shape
        if mismatch:
            total = freq.sum()
            freq.fill(total)

        # Calculate SIF score
        smoothing = self.a
        return smoothing / (smoothing + freq / self.tokens)
|
txtai/scoring/sparse.py
ADDED
@@ -0,0 +1,218 @@
|
|
1
|
+
"""
|
2
|
+
Sparse module
|
3
|
+
"""
|
4
|
+
|
5
|
+
from queue import Queue
|
6
|
+
from threading import Thread
|
7
|
+
|
8
|
+
from ..ann import SparseANNFactory
|
9
|
+
from ..vectors import SparseVectorsFactory
|
10
|
+
|
11
|
+
from .base import Scoring
|
12
|
+
|
13
|
+
|
14
|
+
class Sparse(Scoring):
    """
    Sparse vector scoring.
    """

    # End of stream message
    COMPLETE = 1

    def __init__(self, config=None, models=None):
        """
        Creates a new Sparse scoring instance.

        Args:
            config: scoring configuration parameters
            models: models cache, passed to the sparse vectors factory
        """

        super().__init__(config)

        # Vector configuration. The scoring-level "method" and "normalize" settings are removed and
        # replaced with the "vectormethod" and "vectornormalize" settings, when present.
        # A missing config is treated as an empty configuration.
        mapping = {"vectormethod": "method", "vectornormalize": "normalize"}
        source = config if config is not None else {}
        vconfig = {k: v for k, v in source.items() if k not in mapping.values()}
        for k, v in mapping.items():
            if k in vconfig:
                vconfig[v] = vconfig[k]

        # Load the SparseVectors model
        self.model = SparseVectorsFactory.create(vconfig, models)

        # Normalize search outputs if vectors are not normalized already
        # A float can also be provided to set the normalization factor (defaults to 30.0)
        self.isnormalize = self.config.get("normalize", True)

        # Sparse ANN
        self.ann = None

        # Encoding processing parameters
        self.batch = self.config.get("batch", 1024)
        self.thread, self.queue, self.data = None, None, None

    def insert(self, documents, index=None, checkpoint=None):
        """
        Queues a batch of documents for encoding by the background thread.

        Args:
            documents: iterable of (id, document, tags) tuples
            index: unused by this method
            checkpoint: checkpoint directory, only used when this call starts the encoding thread
        """

        # Start processing thread, if necessary
        self.start(checkpoint)

        data = []
        for uid, document, tags in documents:
            # Extract text, if necessary
            if isinstance(document, dict):
                document = document.get(self.text, document.get(self.object))

            if document is not None:
                # Add data
                data.append((uid, " ".join(document) if isinstance(document, list) else document, tags))

        # Add batch of data
        self.queue.put(data)

    def delete(self, ids):
        self.ann.delete(ids)

    def index(self, documents=None):
        """
        Builds a new sparse ANN index from pending data.

        Args:
            documents: optional documents to insert before indexing
        """

        # Insert documents, if provided
        if documents:
            self.insert(documents)

        # Create ANN, if there is pending data
        embeddings = self.stop()
        if embeddings is not None:
            self.ann = SparseANNFactory.create(self.config)
            self.ann.index(embeddings)

    def upsert(self, documents=None):
        """
        Appends pending data to an existing sparse ANN index. Builds a new index if one doesn't exist.

        Args:
            documents: optional documents to insert before upserting
        """

        # Insert documents, if provided
        if documents:
            self.insert(documents)

        # Check for existing index and pending data
        if self.ann:
            embeddings = self.stop()
            if embeddings is not None:
                self.ann.append(embeddings)
        else:
            self.index()

    def weights(self, tokens):
        # Not supported
        return None

    def search(self, query, limit=3):
        return self.batchsearch([query], limit)[0]

    def batchsearch(self, queries, limit=3, threads=True):
        """
        Runs a batch of queries against the sparse ANN index.

        Args:
            queries: list of queries
            limit: maximum results per query
            threads: unused by this method

        Returns:
            list of results per query
        """

        # Convert queries to embedding vectors
        embeddings = self.model.batchtransform((None, query, None) for query in queries)

        # Run ANN search
        scores = self.ann.search(embeddings, limit)

        # Normalize scores if normalization IS enabled AND vector normalization IS NOT enabled
        return self.normalize(embeddings, scores) if self.isnormalize and not self.model.isnormalize else scores

    def count(self):
        return self.ann.count()

    def load(self, path):
        self.ann = SparseANNFactory.create(self.config)
        self.ann.load(path)

    def save(self, path):
        # Save Sparse ANN
        if self.ann:
            self.ann.save(path)

    def close(self):
        # Drain a running encoding thread. Without the end of stream message, the thread
        # would block on the queue indefinitely once the references below are cleared.
        if self.thread and self.thread.is_alive():
            self.stop()

        # Close Sparse ANN
        if self.ann:
            self.ann.close()

        # Clear parameters
        self.model, self.ann, self.thread, self.queue, self.data = None, None, None, None, None

    def issparse(self):
        return True

    def isnormalized(self):
        return self.isnormalize or self.model.isnormalize

    def start(self, checkpoint):
        """
        Starts an encoding processing thread.

        Args:
            checkpoint: checkpoint directory
        """

        if not self.thread:
            # Bounded queue, producers block once 5 batches are pending
            self.queue = Queue(5)
            self.thread = Thread(target=self.encode, args=(checkpoint,))
            self.thread.start()

    def stop(self):
        """
        Stops an encoding processing thread. Return processed results.

        Returns:
            results
        """

        results = None
        if self.thread:
            # Send EOS message
            self.queue.put(Sparse.COMPLETE)

            self.thread.join()
            self.thread, self.queue = None, None

            # Get return value
            results = self.data
            self.data = None

        return results

    def encode(self, checkpoint):
        """
        Encodes streaming data.

        Args:
            checkpoint: checkpoint directory
        """

        # Streaming encoding of data
        _, dimensions, self.data = self.model.vectors(self.stream(), self.batch, checkpoint)

        # Save number of dimensions
        self.config["dimensions"] = dimensions

    def stream(self):
        """
        Streams data from an input queue until end of stream message received.
        """

        batch = self.queue.get()
        while batch != Sparse.COMPLETE:
            yield from batch
            batch = self.queue.get()

    def normalize(self, queries, scores):
        """
        Normalize query result using the max query score.

        Args:
            queries: query vectors
            scores: query results

        Returns:
            normalized query results
        """

        # Get normalize scale factor
        scale = 30.0 if isinstance(self.isnormalize, bool) else self.isnormalize

        # Normalize scores using max scores
        maxscores = self.model.dot(queries, queries)

        # Normalize results and return
        results = []
        for x, result in enumerate(scores):
            # Max score is the scaled query self-score with a floor of scale
            maxscore = max(maxscores[x][x] / scale, scale)

            # Raise the max score to the first result's score when that is larger
            maxscore = max(maxscore, result[0][1]) if result else maxscore

            results.append([(uid, score / maxscore) for uid, score in result])

        return results
|