mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,121 @@
1
+ """
2
+ Factory module
3
+ """
4
+
5
+ from ...util import Resolver
6
+
7
+ from .external import External
8
+ from .huggingface import HFVectors
9
+ from .litellm import LiteLLM
10
+ from .llama import LlamaCpp
11
+ from .m2v import Model2Vec
12
+ from .sbert import STVectors
13
+ from .words import WordVectors
14
+
15
+
16
class VectorsFactory:
    """
    Methods to create dense vector models.
    """

    @staticmethod
    def create(config, scoring=None, models=None):
        """
        Create a Vectors model instance.

        Args:
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Returns:
            Vectors instance or None when no vector method can be derived or a
            path-based method (sentence-transformers, transformers) has no path set
        """

        # Determine vector method
        method = VectorsFactory.method(config)

        # External vectors
        if method == "external":
            return External(config, scoring, models)

        # LiteLLM vectors
        if method == "litellm":
            return LiteLLM(config, scoring, models)

        # llama.cpp vectors
        if method == "llama.cpp":
            return LlamaCpp(config, scoring, models)

        # Model2vec vectors
        if method == "model2vec":
            return Model2Vec(config, scoring, models)

        # Sentence Transformers vectors
        if method == "sentence-transformers":
            return STVectors(config, scoring, models) if config and config.get("path") else None

        # Word vectors
        if method == "words":
            return WordVectors(config, scoring, models)

        # Transformers vectors
        if HFVectors.ismethod(method):
            return HFVectors(config, scoring, models) if config and config.get("path") else None

        # Resolve custom method
        return VectorsFactory.resolve(method, config, scoring, models) if method else None

    @staticmethod
    def method(config):
        """
        Get or derive the vector method.

        Args:
            config: vector configuration, may be None or empty

        Returns:
            vector method or None if a method is not set and cannot be inferred
        """

        # No configuration means no vector method. This keeps create() consistent with
        # its own "config and ..." guards instead of failing on config.get.
        if not config:
            return None

        # Determine vector method
        method = config.get("method")
        path = config.get("path")

        # Infer method from path, if blank
        if not method:
            if path:
                if LiteLLM.ismodel(path):
                    method = "litellm"
                elif LlamaCpp.ismodel(path):
                    method = "llama.cpp"
                elif Model2Vec.ismodel(path):
                    method = "model2vec"
                elif WordVectors.ismodel(path):
                    method = "words"
                else:
                    # Default when no other backend claims the path
                    method = "transformers"
            elif config.get("transform"):
                # No path but a transform is configured
                method = "external"

        return method

    @staticmethod
    def resolve(backend, config, scoring, models):
        """
        Attempt to resolve a custom backend.

        Args:
            backend: backend class
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Returns:
            Vectors

        Raises:
            ImportError: if the backend cannot be resolved or instantiated
        """

        try:
            return Resolver()(backend)(config, scoring, models)
        except Exception as e:
            # Chain the original error so the root cause stays visible
            raise ImportError(f"Unable to resolve vectors backend: '{backend}'") from e
@@ -0,0 +1,44 @@
1
+ """
2
+ Hugging Face module
3
+ """
4
+
5
+ from ...models import Models, PoolingFactory
6
+
7
+ from ..base import Vectors
8
+
9
+
10
class HFVectors(Vectors):
    """
    Vectors implementation that builds embeddings with the Hugging Face transformers library.
    """

    @staticmethod
    def ismethod(method):
        """
        Tests whether a method name maps to a local transformers-based model.

        Args:
            method: input method

        Returns:
            True if method is one of the transformers pooling methods, False otherwise
        """

        supported = ("transformers", "pooling", "clspooling", "meanpooling")
        return method in supported

    def loadmodel(self, path):
        # Pooling parameters are read from the vector configuration
        config = self.config
        parameters = {
            "method": config.get("method"),
            "path": path,
            "device": Models.deviceid(config.get("gpu", True)),
            "tokenizer": config.get("tokenizer"),
            "maxlength": config.get("maxlength"),
            "modelargs": config.get("vectors", {}),
        }

        # Create the transformers pooling model
        return PoolingFactory.create(parameters)

    def encode(self, data, category=None):
        # Delegate to the pooling model using the configured encode batch size
        model = self.model
        return model.encode(data, batch=self.encodebatch, category=category)
@@ -0,0 +1,86 @@
1
+ """
2
+ LiteLLM module
3
+ """
4
+
5
+ import numpy as np
6
+
7
+ from transformers.utils import cached_file
8
+
9
+ # Conditional import
10
+ try:
11
+ import litellm as api
12
+
13
+ LITELLM = True
14
+ except ImportError:
15
+ LITELLM = False
16
+
17
+ from ..base import Vectors
18
+
19
+
20
class LiteLLM(Vectors):
    """
    Builds vectors using an external embeddings API via LiteLLM.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a LiteLLM model.

        Args:
            path: input path

        Returns:
            True if this is a LiteLLM model, False otherwise
        """

        # Only string paths are supported and LiteLLM must be installed
        if not (isinstance(path, str) and LITELLM):
            return False

        # Save current debug setting so it can be restored after the check
        debug = api.suppress_debug_info

        # pylint: disable=W0718
        try:
            # Suppress debug messages for this test
            api.suppress_debug_info = True

            # A provider must resolve for the path and the path must not be a HF Hub model
            return bool(api.get_llm_provider(path) and not LiteLLM.ishub(path))
        except Exception:
            # Any lookup failure means this is not a LiteLLM model.
            # KeyboardInterrupt and SystemExit are deliberately not caught.
            return False
        finally:
            # Restore debug info value to original value
            api.suppress_debug_info = debug

    @staticmethod
    def ishub(path):
        """
        Checks if path is available on the HF Hub.

        Args:
            path: input path

        Returns:
            True if this is a model on the HF Hub
        """

        # pylint: disable=W0718
        try:
            # Only paths with a "/" are checked
            if "/" not in path:
                return False

            return cached_file(path_or_repo_id=path, filename="config.json") is not None
        except Exception:
            # Lookup errors are treated as not being on the hub.
            # KeyboardInterrupt and SystemExit are deliberately not caught.
            return False

    def __init__(self, config, scoring, models):
        # Check before parent constructor since it calls loadmodel
        if not LITELLM:
            raise ImportError('LiteLLM is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # No local model to load, encode calls the LiteLLM API directly
        return None

    def encode(self, data, category=None):
        # Call external embeddings API using LiteLLM
        # Batching is handled server-side
        response = api.embedding(model=self.config.get("path"), input=data, **self.config.get("vectors", {}))

        # Read response into a NumPy array
        return np.array([x["embedding"] for x in response.data], dtype=np.float32)
@@ -0,0 +1,84 @@
1
+ """
2
+ Llama module
3
+ """
4
+
5
+ import os
6
+
7
+ import numpy as np
8
+
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ # Conditional import
12
+ try:
13
+ from llama_cpp import Llama
14
+
15
+ LLAMA_CPP = True
16
+ except ImportError:
17
+ LLAMA_CPP = False
18
+
19
+ from ..base import Vectors
20
+
21
+
22
class LlamaCpp(Vectors):
    """
    Builds vectors using llama.cpp.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a llama.cpp model.

        Args:
            path: input path

        Returns:
            True if this is a llama.cpp model, False otherwise
        """

        # Detection is based on the .gguf file extension (case-insensitive)
        return isinstance(path, str) and path.lower().endswith(".gguf")

    def __init__(self, config, scoring, models):
        # Check before parent constructor since it calls loadmodel
        if not LLAMA_CPP:
            raise ImportError('llama.cpp is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Check if this is a local path, otherwise download from the HF Hub
        path = path if os.path.exists(path) else self.download(path)

        # Additional model arguments. Work on a copy so the defaults set below and the
        # verbose pop do not modify the vector configuration.
        modelargs = dict(self.config.get("vectors", {}))

        # Default GPU layers if not already set
        gpu = self.config.get("gpu", os.environ.get("LLAMA_NO_METAL") != "1")
        modelargs.setdefault("n_gpu_layers", -1 if gpu else 0)

        # Verbose is passed explicitly, remove it from the remaining keyword arguments
        verbose = modelargs.pop("verbose", False)

        # Create llama.cpp instance
        return Llama(path, n_ctx=0, verbose=verbose, embedding=True, **modelargs)

    def encode(self, data, category=None):
        # Generate embeddings and return as a NumPy array
        # llama-cpp-python has its own batching built-in using n_batch parameter
        return np.array(self.model.embed(data), dtype=np.float32)

    def download(self, path):
        """
        Downloads path from the Hugging Face Hub.

        Args:
            path: full model path

        Returns:
            local cached model path
        """

        # Split into parts
        parts = path.split("/")

        # Calculate repo id split. Paths with more than two parts use the first two
        # parts as the repo id, otherwise only the first part is used.
        repo = 2 if len(parts) > 2 else 1

        # Download and cache file
        return hf_hub_download(repo_id="/".join(parts[:repo]), filename="/".join(parts[repo:]))
@@ -0,0 +1,67 @@
1
+ """
2
+ Model2Vec module
3
+ """
4
+
5
+ import json
6
+
7
+ from huggingface_hub.errors import HFValidationError
8
+ from transformers.utils import cached_file
9
+
10
+ # Conditional import
11
+ try:
12
+ from model2vec import StaticModel
13
+
14
+ MODEL2VEC = True
15
+ except ImportError:
16
+ MODEL2VEC = False
17
+
18
+ from ..base import Vectors
19
+
20
+
21
class Model2Vec(Vectors):
    """
    Vectors implementation backed by Model2Vec static models.
    """

    @staticmethod
    def ismodel(path):
        """
        Tests whether path points to a Model2Vec model.

        Args:
            path: input path

        Returns:
            True if the model's config.json has model_type set to "model2vec", False otherwise
        """

        try:
            # Locate config.json for this path
            configpath = cached_file(path_or_repo_id=path, filename="config.json")
            if not configpath:
                return False

            # Parse JSON and check the declared model type
            with open(configpath, encoding="utf-8") as handle:
                return json.load(handle).get("model_type") == "model2vec"

        except (HFValidationError, OSError):
            # Invalid repo or directory - not a Model2Vec model
            return False

    def __init__(self, config, scoring, models):
        # The parent constructor calls loadmodel, so the dependency check has to run first
        if not MODEL2VEC:
            raise ImportError('Model2Vec is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Load the static model from path
        return StaticModel.from_pretrained(path)

    def encode(self, data, category=None):
        # Extra keyword arguments come from the "vectors" configuration section
        kwargs = self.config.get("vectors", {})

        # Encode using the configured batch size
        return self.model.encode(data, batch_size=self.encodebatch, **kwargs)
@@ -0,0 +1,92 @@
1
+ """
2
+ Sentence Transformers module
3
+ """
4
+
5
+ # Conditional import
6
+ try:
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ SENTENCE_TRANSFORMERS = True
10
+ except ImportError:
11
+ SENTENCE_TRANSFORMERS = False
12
+
13
+ from ...models import Models
14
+
15
+ from ..base import Vectors
16
+
17
+
18
class STVectors(Vectors):
    """
    Vectors implementation backed by sentence-transformers (aka SBERT).
    """

    def __init__(self, config, scoring, models):
        # The parent constructor calls loadmodel, so the dependency check has to run first
        if not SENTENCE_TRANSFORMERS:
            raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

        # Multi-process pool handle, set by loadmodel when pooling is enabled
        self.pool = None

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Requested device setting, multiprocessing is off by default
        gpu = self.config.get("gpu", True)
        multiprocess = False

        # gpu="all" enables a multi-process pool when more than one accelerator is available
        if isinstance(gpu, str) and gpu == "all":
            count = Models.acceleratorcount()

            # With a single (or no) device fall back to the default single device mode
            gpu = count <= 1
            multiprocess = count > 1

        # Resolve tensor device id
        deviceid = Models.deviceid(gpu)

        # Extra keyword arguments for the encoder
        kwargs = self.config.get("vectors", {})

        # Build the sentence-transformers encoder
        encoder = self.loadencoder(path, device=Models.device(deviceid), **kwargs)

        # Start the process pool when running on multiple devices
        if multiprocess:
            self.pool = encoder.start_multi_process_pool()

        return encoder

    def encode(self, data, category=None):
        # Select the encode method matching the input category
        if category == "query":
            method = self.model.encode_query
        elif category == "data":
            method = self.model.encode_document
        else:
            method = self.model.encode

        # Extra keyword arguments for encoding
        kwargs = self.config.get("encodeargs", {})

        # Run the sentence-transformers encoder
        return method(data, pool=self.pool, batch_size=self.encodebatch, **kwargs)

    def close(self):
        # The pool must be stopped before the parent method closes the model
        if self.pool:
            self.model.stop_multi_process_pool(self.pool)
            self.pool = None

        super().close()

    def loadencoder(self, path, device, **kwargs):
        """
        Creates the embeddings encoder model for path.

        Args:
            path: model path
            device: tensor device
            kwargs: additional keyword args

        Returns:
            embeddings encoder
        """

        encoder = SentenceTransformer(path, device=device, **kwargs)
        return encoder
@@ -0,0 +1,211 @@
1
+ """
2
+ Word Vectors module
3
+ """
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ import tempfile
9
+
10
+ from multiprocessing import Pool
11
+
12
+ import numpy as np
13
+
14
+ from huggingface_hub.errors import HFValidationError
15
+ from transformers.utils import cached_file
16
+
17
+ # Conditional import
18
+ try:
19
+ from staticvectors import Database, StaticVectors
20
+
21
+ STATICVECTORS = True
22
+ except ImportError:
23
+ STATICVECTORS = False
24
+
25
+ from ...pipeline import Tokenizer
26
+
27
+ from ..base import Vectors
28
+
29
# Logging configuration
logger = logging.getLogger(__name__)

# Multiprocessing helper methods
# PARAMETERS stores the WordVectors constructor arguments set by create().
# VECTORS stores the WordVectors instance lazily built by transform().
# pylint: disable=W0603
PARAMETERS, VECTORS = None, None
35
+
36
+
37
def create(config, scoring):
    """
    Multiprocessing helper method. Records the parameters needed to build a vectors model in a new subprocess.

    Args:
        config: vector configuration
        scoring: scoring instance
    """

    global PARAMETERS, VECTORS

    # Keep constructor arguments (config, scoring, models=None) for lazy loading
    PARAMETERS = (config, scoring, None)

    # Clear the model, it is built on the first call to transform
    VECTORS = None
51
+
52
+
53
def transform(document):
    """
    Multiprocessing helper method. Builds an embeddings vector for a document.

    Args:
        document: (id, data, tags)

    Returns:
        (id, embedding)
    """

    global VECTORS

    # Build the vectors model on first use from the stored parameters
    if not VECTORS:
        VECTORS = WordVectors(*PARAMETERS)

    uid = document[0]
    return (uid, VECTORS.transform(document))
70
+
71
+
72
class WordVectors(Vectors):
    """
    Builds vectors using weighted word embeddings.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a WordVectors model.

        Args:
            path: input path

        Returns:
            True if this is a WordVectors model, False otherwise
        """

        # Check if this is a SQLite database
        if WordVectors.isdatabase(path):
            return True

        try:
            # Download file and parse JSON
            path = cached_file(path_or_repo_id=path, filename="config.json")
            if path:
                with open(path, encoding="utf-8") as f:
                    config = json.load(f)
                    return config.get("model_type") == "staticvectors"

        # Ignore this error - invalid repo or directory
        except (HFValidationError, OSError):
            pass

        return False

    @staticmethod
    def isdatabase(path):
        """
        Checks if this is a SQLite database file which is the file format used for word vectors databases.

        Args:
            path: path to check

        Returns:
            True if this is a SQLite database
        """

        return isinstance(path, str) and STATICVECTORS and Database.isdatabase(path)

    def __init__(self, config, scoring, models):
        # Check before parent constructor since it calls loadmodel
        if not STATICVECTORS:
            raise ImportError('staticvectors is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Load the word vectors model from path
        return StaticVectors(path)

    def encode(self, data, category=None):
        # Iterate over each data element, tokenize (if necessary) and build an aggregated embeddings vector
        embeddings = []
        for tokens in data:
            # Convert to tokens, if necessary. If tokenized list is empty, use input string.
            if isinstance(tokens, str):
                tokenlist = Tokenizer.tokenize(tokens)
                tokens = tokenlist if tokenlist else [tokens]

            # Generate weights for each vector using a scoring method
            weights = self.scoring.weights(tokens) if self.scoring else None

            # Weighted average requires at least one positive weight
            # pylint: disable=E1133
            if weights and any(x > 0 for x in weights):
                # Build weighted average embeddings vector. Create weights array as float32 to match embeddings precision.
                embedding = np.average(self.lookup(tokens), weights=np.array(weights, dtype=np.float32), axis=0)
            else:
                # If no weights, use mean
                embedding = np.mean(self.lookup(tokens), axis=0)

            embeddings.append(embedding)

        return np.array(embeddings, dtype=np.float32)

    def index(self, documents, batchsize=500, checkpoint=None):
        """
        Builds embeddings for documents, optionally using a multiprocessing pool.

        Args:
            documents: iterable of (id, data, tags)
            batchsize: number of embeddings written per batch
            checkpoint: passed to the parent index method when single process indexing is used,
                        not used by the multiprocessing path

        Returns:
            (ids, dimensions, batches, stream) where stream is the name of the temporary file with the saved batches
        """

        # Derive number of parallel processes
        parallel = self.config.get("parallel", True)
        parallel = os.cpu_count() if parallel and isinstance(parallel, bool) else int(parallel)

        # Use default single process indexing logic
        if not parallel:
            # Forward checkpoint so it is not silently dropped
            # NOTE(review): assumes the parent index method accepts checkpoint as its third parameter - confirm
            return super().index(documents, batchsize, checkpoint)

        # Customize indexing logic with multiprocessing pool to efficiently build vectors
        ids, dimensions, batches, stream = [], None, 0, None

        # Shared objects with Pool
        args = (self.config, self.scoring)

        # Convert all documents to embedding arrays, stream embeddings to disk to control memory usage
        with Pool(parallel, initializer=create, initargs=args) as pool:
            # delete=False since the file name is returned to the caller
            with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False) as output:
                stream = output.name
                embeddings = []
                for uid, embedding in pool.imap(transform, documents, self.encodebatch):
                    if not dimensions:
                        # Set number of dimensions for embeddings
                        dimensions = embedding.shape[0]

                    ids.append(uid)
                    embeddings.append(embedding)

                    if len(embeddings) == batchsize:
                        np.save(output, np.array(embeddings, dtype=np.float32), allow_pickle=False)
                        batches += 1

                        embeddings = []

                # Final embeddings batch
                if embeddings:
                    np.save(output, np.array(embeddings, dtype=np.float32), allow_pickle=False)
                    batches += 1

        return (ids, dimensions, batches, stream)

    def lookup(self, tokens):
        """
        Queries word vectors for given list of input tokens.

        Args:
            tokens: list of tokens to query

        Returns:
            word vectors array
        """

        return self.model.embeddings(tokens)

    def tokens(self, data):
        # Skip tokenization rules
        return data