mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,95 @@
1
+ """
2
+ Factory module
3
+ """
4
+
5
+ from ..util import Resolver
6
+
7
+ from .bm25 import BM25
8
+ from .pgtext import PGText
9
+ from .sif import SIF
10
+ from .sparse import Sparse
11
+ from .tfidf import TFIDF
12
+
13
+
14
class ScoringFactory:
    """
    Methods to create Scoring indexes.
    """

    @staticmethod
    def create(config, models=None):
        """
        Factory method to construct a Scoring instance.

        Args:
            config: scoring configuration parameters, either a method name (str) or a dict.
                    A dict is updated in place with the resolved "method".
            models: models cache, only forwarded to the "sparse" method

        Returns:
            Scoring

        Raises:
            ImportError: if the method is not built in and can't be resolved as a custom backend
        """

        # Support string and dict configuration
        if isinstance(config, str):
            config = {"method": config}

        # Get scoring method, defaults to bm25 when not set
        method = config.get("method", "bm25")

        if method == "bm25":
            scoring = BM25(config)
        elif method == "pgtext":
            scoring = PGText(config)
        elif method == "sif":
            scoring = SIF(config)
        elif method == "sparse":
            scoring = Sparse(config, models)
        elif method == "tfidf":
            scoring = TFIDF(config)
        else:
            # Resolve custom method
            scoring = ScoringFactory.resolve(method, config)

        # Store config back, this persists the default method when none was provided
        config["method"] = method

        return scoring

    @staticmethod
    def issparse(config):
        """
        Checks if this scoring configuration builds a sparse index.

        Args:
            config: scoring configuration

        Returns:
            True if this config is for a sparse index, False otherwise
        """

        # Types that are always a sparse index
        indexes = ["pgtext", "sparse"]

        # Only non-empty dict configurations can describe a sparse index
        if not config or not isinstance(config, dict):
            return False

        # Always return a strict boolean instead of leaking the raw "terms" value
        return bool(config.get("method") in indexes or config.get("terms"))

    @staticmethod
    def resolve(backend, config):
        """
        Attempt to resolve a custom backend.

        Args:
            backend: backend class
            config: index configuration parameters

        Returns:
            Scoring

        Raises:
            ImportError: if the backend can't be resolved or fails to construct
        """

        try:
            return Resolver()(backend)(config)
        except Exception as e:
            # Any failure (lookup or construction) is reported as an import problem, original error is chained
            raise ImportError(f"Unable to resolve scoring backend: '{backend}'") from e
@@ -0,0 +1,181 @@
1
+ """
2
+ PGText module
3
+ """
4
+
5
+ import os
6
+
7
+ # Conditional import
8
+ try:
9
+ from sqlalchemy import create_engine, desc, delete, func, text
10
+ from sqlalchemy import Column, Computed, Index, Integer, MetaData, StaticPool, Table, Text
11
+ from sqlalchemy.dialects.postgresql import TSVECTOR
12
+ from sqlalchemy.orm import Session
13
+ from sqlalchemy.schema import CreateSchema
14
+
15
+ PGTEXT = True
16
+ except ImportError:
17
+ PGTEXT = False
18
+
19
+ from .base import Scoring
20
+
21
+
22
class PGText(Scoring):
    """
    Postgres full text search (FTS) based scoring.
    """

    def __init__(self, config=None):
        super().__init__(config)

        if not PGTEXT:
            raise ImportError('PGText is not available - install "scoring" extra to enable')

        # Database connection, created lazily by initialize()
        self.engine, self.database, self.connection, self.table = None, None, None, None

        # Language passed to to_tsvector / plainto_tsquery
        self.language = self.config.get("language", "english")

    def insert(self, documents, index=None, checkpoint=None):
        # Initialize tables
        self.initialize(recreate=True)

        # Collection of rows to insert
        rows = []

        # Collect rows
        for uid, document, _ in documents:
            # Extract text, if necessary
            # NOTE(review): self.text and self.object are not set in this class, presumably field names from the base class
            if isinstance(document, dict):
                document = document.get(self.text, document.get(self.object))

            if document is not None:
                # If index is passed, use indexid, otherwise use id
                uid = index if index is not None else uid

                # Add row if the data type is accepted
                if isinstance(document, (str, list)):
                    rows.append((uid, " ".join(document) if isinstance(document, list) else document))

                # Increment index
                index = index + 1 if index is not None else None

        # Insert rows. An empty parameter list is not an executemany of zero rows, so skip when nothing was collected.
        if rows:
            self.database.execute(self.table.insert(), [{"indexid": x, "text": content} for x, content in rows])

    def delete(self, ids):
        self.database.execute(delete(self.table).where(self.table.c["indexid"].in_(ids)))

    def weights(self, tokens):
        # Not supported
        return None

    def search(self, query, limit=3):
        # Rank rows with ts_rank against the query parsed by plainto_tsquery, best matches first
        results = (
            self.database.query(self.table.c["indexid"], text("ts_rank(vector, plainto_tsquery(:language, :query)) rank"))
            .order_by(desc(text("rank")))
            .limit(limit)
            .params({"language": self.language, "query": query})
        )

        # Drop results with a near zero rank
        return [(uid, score) for uid, score in results if score > 1e-5]

    def batchsearch(self, queries, limit=3, threads=True):
        # Queries run sequentially, the threads argument is not used
        return [self.search(query, limit) for query in queries]

    def count(self):
        # pylint: disable=E1102
        return self.database.query(func.count(self.table.c["indexid"])).scalar()

    def load(self, path):
        # Reset database to original checkpoint, path is not used
        if self.database:
            self.database.rollback()
            self.connection.rollback()

        # Initialize tables
        self.initialize()

    def save(self, path):
        # Commit session and connection, path is not used
        if self.database:
            self.database.commit()
            self.connection.commit()

    def close(self):
        if self.database:
            self.database.close()
            self.engine.dispose()

    def issparse(self):
        return True

    def isnormalized(self):
        return True

    def initialize(self, recreate=False):
        """
        Initializes a new database session. This is a no-op when a session already exists.

        Args:
            recreate: Recreates the database tables if True
        """

        if not self.database:
            # Create engine, connection and session. Connection url falls back to the SCORING_URL environment variable.
            self.engine = create_engine(self.config.get("url", os.environ.get("SCORING_URL")), poolclass=StaticPool, echo=False)
            self.connection = self.engine.connect()
            self.database = Session(self.connection)

            # Set default schema, if necessary
            schema = self.config.get("schema")
            if schema:
                with self.engine.begin():
                    self.sqldialect(CreateSchema(schema, if_not_exists=True))

                self.sqldialect(text("SET search_path TO :schema"), {"schema": schema})

            # Table name
            table = self.config.get("table", "scoring")

            # Create vectors table. The tsvector column is computed from the text column on postgresql,
            # other dialects get a placeholder integer column.
            self.table = Table(
                table,
                MetaData(),
                Column("indexid", Integer, primary_key=True, autoincrement=False),
                Column("text", Text),
                (
                    Column("vector", TSVECTOR, Computed(f"to_tsvector('{self.language}', text)", persisted=True))
                    if self.engine.dialect.name == "postgresql"
                    else Column("vector", Integer)
                ),
            )

            # Create text index
            index = Index(
                f"{table}-index",
                self.table.c["vector"],
                postgresql_using="gin",
            )

            # Drop and recreate table
            if recreate:
                self.table.drop(self.connection, checkfirst=True)
                index.drop(self.connection, checkfirst=True)

            # Create table and index
            self.table.create(self.connection, checkfirst=True)
            index.create(self.connection, checkfirst=True)

    def sqldialect(self, sql, parameters=None):
        """
        Executes a SQL statement based on the current SQL dialect.

        Args:
            sql: SQL to execute
            parameters: optional bind parameters
        """

        # Only postgresql runs the statement, other dialects run a placeholder SELECT 1
        args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
        self.database.execute(*args)
txtai/scoring/sif.py ADDED
@@ -0,0 +1,32 @@
1
+ """
2
+ SIF module
3
+ """
4
+
5
+ import numpy as np
6
+
7
+ from .tfidf import TFIDF
8
+
9
+
10
class SIF(TFIDF):
    """
    Smooth Inverse Frequency (SIF) scoring.
    """

    def __init__(self, config=None):
        super().__init__(config)

        # Smoothing parameter "a" used by the SIF formula, configurable
        self.a = self.config.get("a", 1e-3)

    def computefreq(self, tokens):
        # SIF looks up each token's frequency across the entire index
        # instead of counting occurrences within a single entry
        frequencies = {}
        for token in tokens:
            frequencies[token] = self.wordfreq[token]

        return frequencies

    def score(self, freq, idf, length):
        # When freq is an array with a different shape than idf, replace every element with the array total
        if isinstance(freq, np.ndarray):
            if freq.shape != np.array(idf).shape:
                total = freq.sum()
                freq.fill(total)

        # SIF score: a / (a + relative frequency)
        relative = freq / self.tokens
        return self.a / (self.a + relative)
@@ -0,0 +1,218 @@
1
+ """
2
+ Sparse module
3
+ """
4
+
5
+ from queue import Queue
6
+ from threading import Thread
7
+
8
+ from ..ann import SparseANNFactory
9
+ from ..vectors import SparseVectorsFactory
10
+
11
+ from .base import Scoring
12
+
13
+
14
class Sparse(Scoring):
    """
    Sparse vector scoring.
    """

    # End of stream message
    COMPLETE = 1

    def __init__(self, config=None, models=None):
        super().__init__(config)

        # Vector configuration: scoring-level "method"/"normalize" keys are removed and replaced
        # with the values of "vectormethod"/"vectornormalize", when those are set.
        # Tolerate the default config=None instead of failing on None.items().
        mapping = {"vectormethod": "method", "vectornormalize": "normalize"}
        config = {k: v for k, v in (config or {}).items() if k not in mapping.values()}
        for k, v in mapping.items():
            if k in config:
                config[v] = config[k]

        # Load the SparseVectors model
        self.model = SparseVectorsFactory.create(config, models)

        # Normalize search outputs if vectors are not normalized already
        # A float can also be provided to set the normalization factor (defaults to 30.0)
        self.isnormalize = self.config.get("normalize", True)

        # Sparse ANN
        self.ann = None

        # Encoding processing parameters
        self.batch = self.config.get("batch", 1024)
        self.thread, self.queue, self.data = None, None, None

    def insert(self, documents, index=None, checkpoint=None):
        # Start processing thread, if necessary
        self.start(checkpoint)

        data = []
        for uid, document, tags in documents:
            # Extract text, if necessary
            # NOTE(review): self.text and self.object are not set in this class, presumably field names from the base class
            if isinstance(document, dict):
                document = document.get(self.text, document.get(self.object))

            if document is not None:
                # Add data
                data.append((uid, " ".join(document) if isinstance(document, list) else document, tags))

        # Add batch of data, consumed by the encoding thread
        self.queue.put(data)

    def delete(self, ids):
        self.ann.delete(ids)

    def index(self, documents=None):
        # Insert documents, if provided
        if documents:
            self.insert(documents)

        # Create ANN, if there is pending data
        embeddings = self.stop()
        if embeddings is not None:
            self.ann = SparseANNFactory.create(self.config)
            self.ann.index(embeddings)

    def upsert(self, documents=None):
        # Insert documents, if provided
        if documents:
            self.insert(documents)

        # Check for existing index and pending data
        if self.ann:
            embeddings = self.stop()
            if embeddings is not None:
                self.ann.append(embeddings)
        else:
            self.index()

    def weights(self, tokens):
        # Not supported
        return None

    def search(self, query, limit=3):
        return self.batchsearch([query], limit)[0]

    def batchsearch(self, queries, limit=3, threads=True):
        # Convert queries to embedding vectors
        embeddings = self.model.batchtransform((None, query, None) for query in queries)

        # Run ANN search
        scores = self.ann.search(embeddings, limit)

        # Normalize scores if normalization IS enabled AND vector normalization IS NOT enabled
        return self.normalize(embeddings, scores) if self.isnormalize and not self.model.isnormalize else scores

    def count(self):
        return self.ann.count()

    def load(self, path):
        self.ann = SparseANNFactory.create(self.config)
        self.ann.load(path)

    def save(self, path):
        # Save Sparse ANN
        if self.ann:
            self.ann.save(path)

    def close(self):
        # Stop a running encoding thread first. Without the end of stream message, the thread stays
        # blocked on the queue after the references below are cleared. Pending results are discarded.
        if self.thread:
            self.stop()

        # Close Sparse ANN
        if self.ann:
            self.ann.close()

        # Clear parameters
        self.model, self.ann, self.thread, self.queue, self.data = None, None, None, None, None

    def issparse(self):
        return True

    def isnormalized(self):
        return self.isnormalize or self.model.isnormalize

    def start(self, checkpoint):
        """
        Starts an encoding processing thread.

        Args:
            checkpoint: checkpoint directory
        """

        if not self.thread:
            # Bounded queue, insert() blocks when the encoder falls behind
            self.queue = Queue(5)
            self.thread = Thread(target=self.encode, args=(checkpoint,))
            self.thread.start()

    def stop(self):
        """
        Stops an encoding processing thread. Return processed results.

        Returns:
            results, None if no encoding thread was running
        """

        results = None
        if self.thread:
            # Send EOS message
            self.queue.put(Sparse.COMPLETE)

            # Wait for the encoder to drain the queue and finish
            self.thread.join()
            self.thread, self.queue = None, None

            # Get return value
            results = self.data
            self.data = None

        return results

    def encode(self, checkpoint):
        """
        Encodes streaming data. Runs on the encoding thread and stores results in self.data.

        Args:
            checkpoint: checkpoint directory
        """

        # Streaming encoding of data
        _, dimensions, self.data = self.model.vectors(self.stream(), self.batch, checkpoint)

        # Save number of dimensions
        self.config["dimensions"] = dimensions

    def stream(self):
        """
        Streams data from an input queue until end of stream message received.
        """

        batch = self.queue.get()
        while batch != Sparse.COMPLETE:
            yield from batch
            batch = self.queue.get()

    def normalize(self, queries, scores):
        """
        Normalize query result using the max query score.

        Args:
            queries: query vectors
            scores: query results

        Returns:
            normalized query results
        """

        # Get normalize scale factor, True uses the default of 30.0
        scale = 30.0 if isinstance(self.isnormalize, bool) else self.isnormalize

        # Normalize scores using max scores, the diagonal holds each query scored against itself
        maxscores = self.model.dot(queries, queries)

        # Normalize results and return
        results = []
        for x, result in enumerate(scores):
            # Max score is never lower than the scale factor or the top result score
            maxscore = max(maxscores[x][x] / scale, scale)
            maxscore = max(maxscore, result[0][1]) if result else maxscore

            results.append([(uid, score / maxscore) for uid, score in result])

        return results