mseep-txtai 9.1.1 (mseep_txtai-9.1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/models/pooling/base.py
@@ -0,0 +1,141 @@
+"""
+Pooling module
+"""
+
+import numpy as np
+import torch
+
+from torch import nn
+
+from ..models import Models
+
+
+class Pooling(nn.Module):
+    """
+    Builds pooled vectors using outputs from a transformers model.
+    """
+
+    def __init__(self, path, device, tokenizer=None, maxlength=None, modelargs=None):
+        """
+        Creates a new Pooling model.
+
+        Args:
+            path: path to model, accepts Hugging Face model hub id or local path
+            device: tensor device id
+            tokenizer: optional path to tokenizer
+            maxlength: max sequence length
+            modelargs: additional model arguments
+        """
+
+        super().__init__()
+
+        self.model = Models.load(path, modelargs=modelargs)
+        self.tokenizer = Models.tokenizer(tokenizer if tokenizer else path)
+        self.device = Models.device(device)
+
+        # Detect unbounded tokenizer typically found in older models
+        Models.checklength(self.model, self.tokenizer)
+
+        # Set max length
+        self.maxlength = maxlength if maxlength else self.tokenizer.model_max_length if self.tokenizer.model_max_length != int(1e30) else None
+
+        # Move to device
+        self.to(self.device)
+
+    def encode(self, documents, batch=32, category=None):
+        """
+        Builds an array of pooled embeddings for documents.
+
+        Args:
+            documents: list of documents used to build embeddings
+            batch: model batch size
+            category: embeddings category (query or data)
+
+        Returns:
+            pooled embeddings
+        """
+
+        # Split documents into batches and process
+        results = []
+
+        # Apply pre encoding transformation logic
+        documents = self.preencode(documents, category)
+
+        # Sort document indices from largest to smallest to enable efficient batching
+        # This performance tweak matches logic in sentence-transformers
+        lengths = np.argsort([-len(x) if x else 0 for x in documents])
+        documents = [documents[x] for x in lengths]
+
+        for chunk in self.chunk(documents, batch):
+            # Tokenize input
+            inputs = self.tokenizer(chunk, padding=True, truncation="longest_first", return_tensors="pt", max_length=self.maxlength)
+
+            # Move inputs to device
+            inputs = inputs.to(self.device)
+
+            # Run inputs through model
+            with torch.no_grad():
+                outputs = self.forward(**inputs)
+
+            # Add batch result
+            results.extend(outputs.cpu().to(torch.float32).numpy())
+
+        # Apply post encoding transformation logic
+        results = self.postencode(results, category)
+
+        # Restore original order and return array
+        return np.asarray([results[x] for x in np.argsort(lengths)])
+
+    def chunk(self, texts, size):
+        """
+        Splits texts into separate batch sizes specified by size.
+
+        Args:
+            texts: text elements
+            size: batch size
+
+        Returns:
+            list of evenly sized batches with the last batch having the remaining elements
+        """
+
+        return [texts[x : x + size] for x in range(0, len(texts), size)]
+
+    def forward(self, **inputs):
+        """
+        Runs inputs through transformers model and returns outputs.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            model outputs
+        """
+
+        return self.model(**inputs)[0]
+
+    # pylint: disable=W0613
+    def preencode(self, documents, category):
+        """
+        Applies pre encoding transformation logic.
+
+        Args:
+            documents: list of documents used to build embeddings
+            category: embeddings category (query or data)
+        """
+
+        return documents
+
+    # pylint: disable=W0613
+    def postencode(self, results, category):
+        """
+        Applies post encoding transformation logic.
+
+        Args:
+            results: list of results
+            category: embeddings category (query or data)
+
+        Returns:
+            results with transformation logic applied
+        """
+
+        return results
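Note: encode above sorts documents by descending length for efficient batching, then restores the original order after inference. A minimal sketch of that reorder/restore trick in isolation (toy inputs, no model required; the document strings are placeholders):

import numpy as np

docs = ["a fairly long document", "hi", "medium text"]

# Sort indices from longest to shortest (mirrors the argsort logic in encode)
order = np.argsort([-len(x) if x else 0 for x in docs])
batched = [docs[x] for x in order]

# argsort of a permutation is its inverse, so this restores the original order
restored = [batched[x] for x in np.argsort(order)]
assert restored == docs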
txtai/models/pooling/cls.py
@@ -0,0 +1,28 @@
+"""
+CLS module
+"""
+
+from .base import Pooling
+
+
+class ClsPooling(Pooling):
+    """
+    Builds CLS pooled vectors using outputs from a transformers model.
+    """
+
+    def forward(self, **inputs):
+        """
+        Runs CLS pooling on token embeddings.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            CLS pooled embeddings using output token embeddings (i.e. last hidden state)
+        """
+
+        # Run through transformers model
+        tokens = super().forward(**inputs)
+
+        # CLS token pooling
+        return tokens[:, 0]
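Note: CLS pooling simply keeps the first token vector of the last hidden state. A toy tensor sketch of the same slicing (synthetic shapes, no model required):

import torch

# Toy last hidden state: (batch=2, tokens=4, dims=3)
tokens = torch.randn(2, 4, 3)

# CLS pooling keeps the first token vector per sequence
cls = tokens[:, 0]
print(cls.shape)  # torch.Size([2, 3])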
txtai/models/pooling/factory.py
@@ -0,0 +1,144 @@
+"""
+Factory module
+"""
+
+import json
+import os
+
+from huggingface_hub.errors import HFValidationError
+from transformers.utils import cached_file
+
+from .base import Pooling
+from .cls import ClsPooling
+from .late import LatePooling
+from .mean import MeanPooling
+
+
+class PoolingFactory:
+    """
+    Method to create pooling models.
+    """
+
+    @staticmethod
+    def create(config):
+        """
+        Create a Pooling model.
+
+        Args:
+            config: pooling configuration
+
+        Returns:
+            Pooling
+        """
+
+        # Unpack parameters
+        method, path, device, tokenizer, maxlength, modelargs = [
+            config.get(x) for x in ["method", "path", "device", "tokenizer", "maxlength", "modelargs"]
+        ]
+
+        # Derive maxlength, if applicable
+        maxlength = PoolingFactory.maxlength(path) if isinstance(maxlength, bool) and maxlength else maxlength
+
+        # Default pooling returns hidden state
+        if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)) or method == "pooling":
+            return Pooling(path, device, tokenizer, maxlength, modelargs)
+
+        # Derive pooling method if it's not specified and path is a string
+        if (not method or method not in ("clspooling", "meanpooling", "latepooling")) and isinstance(path, str):
+            method = PoolingFactory.method(path)
+
+        # Check for cls pooling
+        if method == "clspooling":
+            return ClsPooling(path, device, tokenizer, maxlength, modelargs)
+
+        # Check for late pooling
+        if method == "latepooling":
+            return LatePooling(path, device, tokenizer, maxlength, modelargs)
+
+        # Default to mean pooling
+        return MeanPooling(path, device, tokenizer, maxlength, modelargs)
+
+    @staticmethod
+    def method(path):
+        """
+        Determines the pooling method using the sentence transformers pooling config.
+
+        Args:
+            path: model path
+
+        Returns:
+            pooling method
+        """
+
+        # Default method
+        method = "meanpooling"
+
+        # Load 1_Pooling/config.json file
+        config = PoolingFactory.load(path, "1_Pooling/config.json")
+
+        # Set to CLS pooling if it's enabled and mean pooling is disabled
+        if config and config["pooling_mode_cls_token"] and not config["pooling_mode_mean_tokens"]:
+            method = "clspooling"
+
+        # Check for late interaction pooling
+        if not config:
+            # Load 1_Dense/config.json
+            config = PoolingFactory.load(path, "1_Dense/config.json")
+            if config:
+                method = "latepooling"
+
+            # Load config.json and check architecture
+            else:
+                config = PoolingFactory.load(path, "config.json")
+                if config and "HF_ColBERT" in config.get("architectures", []):
+                    method = "latepooling"
+
+        return method
+
+    @staticmethod
+    def maxlength(path):
+        """
+        Reads the max_seq_length parameter from sentence transformers config.
+
+        Args:
+            path: model path
+
+        Returns:
+            max sequence length
+        """
+
+        # Default length is unset
+        maxlength = None
+
+        # Read max_seq_length from sentence_bert_config.json
+        config = PoolingFactory.load(path, "sentence_bert_config.json")
+        maxlength = config.get("max_seq_length") if config else maxlength
+
+        return maxlength
+
+    @staticmethod
+    def load(path, name):
+        """
+        Loads a JSON config file from the Hugging Face Hub.
+
+        Args:
+            path: model path
+            name: file to load
+
+        Returns:
+            config
+        """
+
+        # Download file and parse JSON
+        config = None
+        try:
+            path = cached_file(path_or_repo_id=path, filename=name)
+            if path:
+                with open(path, encoding="utf-8") as f:
+                    config = json.load(f)
+
+        # Ignore this error - invalid repo or directory
+        except (HFValidationError, OSError):
+            pass
+
+        return config
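Note: a minimal usage sketch of PoolingFactory.create. The model id and device value are illustrative assumptions, the import path follows txtai/models/pooling/__init__.py in this diff, and maxlength=True triggers the sentence_bert_config.json lookup shown above:

from txtai.models.pooling import PoolingFactory  # assumed export path

pooling = PoolingFactory.create({
    "method": "meanpooling",
    "path": "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model id
    "device": "cpu",                                   # assumes a torch device string is accepted
    "maxlength": True,
})

embeddings = pooling.encode(["first document", "second document"])
print(embeddings.shape)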
txtai/models/pooling/late.py
@@ -0,0 +1,173 @@
+"""
+Late module
+"""
+
+import json
+
+import numpy as np
+import torch
+
+from huggingface_hub.errors import HFValidationError
+from safetensors import safe_open
+from torch import nn
+from transformers.utils import cached_file
+
+from .base import Pooling
+from .muvera import Muvera
+
+
+class LatePooling(Pooling):
+    """
+    Builds late pooled vectors using outputs from a transformers model.
+    """
+
+    def __init__(self, path, device, tokenizer=None, maxlength=None, modelargs=None):
+        # Check if fixed dimensional encoder is enabled
+        modelargs = modelargs.copy() if modelargs else {}
+        muvera = modelargs.pop("muvera", {})
+        self.encoder = Muvera(**muvera) if muvera is not None else None
+
+        # Call parent initialization
+        super().__init__(path, device, tokenizer, maxlength, modelargs)
+
+        # Get linear weights path
+        config = self.load(path, "1_Dense/config.json")
+        if config:
+            # PyLate weights format
+            name = "1_Dense/model.safetensors"
+        else:
+            # Stanford weights format
+            name = "model.safetensors"
+
+        # Read model settings
+        self.qprefix, self.qlength, self.dprefix, self.dlength = self.settings(path, config)
+
+        # Load linear layer
+        path = cached_file(path_or_repo_id=path, filename=name)
+        with safe_open(filename=path, framework="pt") as f:
+            weights = f.get_tensor("linear.weight")
+
+        # Load weights into linear layer
+        self.linear = nn.Linear(weights.shape[1], weights.shape[0], bias=False, device=self.device, dtype=weights.dtype)
+        with torch.no_grad():
+            self.linear.weight.copy_(weights)
+
+    def forward(self, **inputs):
+        """
+        Runs late pooling on token embeddings.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            Late pooled embeddings using output token embeddings (i.e. last hidden state)
+        """
+
+        # Run through transformers model
+        tokens = super().forward(**inputs)
+
+        # Run through final linear layer and return
+        return self.linear(tokens)
+
+    def preencode(self, documents, category):
+        """
+        Applies prefixes and lengths to data.
+
+        Args:
+            documents: list of documents used to build embeddings
+            category: embeddings category (query or data)
+        """
+
+        results = []
+
+        # Apply prefix
+        for text in documents:
+            prefix = self.qprefix if category == "query" else self.dprefix
+            if prefix:
+                text = f"{prefix}{text}"
+
+            results.append(text)
+
+        # Set maxlength
+        maxlength = self.qlength if category == "query" else self.dlength
+        if maxlength:
+            self.maxlength = maxlength
+
+        return results
+
+    def postencode(self, results, category):
+        """
+        Normalizes and pads results.
+
+        Args:
+            results: input results
+
+        Returns:
+            normalized results with padding
+        """
+
+        length = 0
+        for vectors in results:
+            # Get max length
+            if vectors.shape[0] > length:
+                length = vectors.shape[0]
+
+            # Normalize vectors
+            vectors /= np.linalg.norm(vectors, axis=1)[:, np.newaxis]
+
+        # Pad values
+        data = []
+        for vectors in results:
+            data.append(np.pad(vectors, [(0, length - vectors.shape[0]), (0, 0)]))
+
+        # Build NumPy array
+        data = np.asarray(data)
+
+        # Apply fixed dimensional encoder, if necessary
+        return self.encoder(data, category) if self.encoder else data
+
+    def settings(self, path, config):
+        """
+        Reads model settings.
+
+        Args:
+            path: model path
+            config: PyLate model format if provided, otherwise read from Stanford format
+        """
+
+        if config:
+            # PyLate format
+            config = self.load(path, "config_sentence_transformers.json")
+            params = ["query_prefix", "query_length", "document_prefix", "document_length"]
+        else:
+            # Stanford format
+            config = self.load(path, "artifact.metadata")
+            params = ["query_token_id", "query_maxlen", "doc_token_id", "doc_maxlen"]
+
+        return [config.get(p) for p in params]
+
+    def load(self, path, name):
+        """
+        Loads a JSON config file from the Hugging Face Hub.
+
+        Args:
+            path: model path
+            name: file to load
+
+        Returns:
+            config
+        """
+
+        # Download file and parse JSON
+        config = None
+        try:
+            path = cached_file(path_or_repo_id=path, filename=name)
+            if path:
+                with open(path, encoding="utf-8") as f:
+                    config = json.load(f)
+
+        # Ignore this error - invalid repo or directory
+        except (HFValidationError, OSError):
+            pass
+
+        return config
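Note: a hedged sketch of building a late interaction model through the factory. The model id is a hypothetical placeholder for a PyLate or Stanford ColBERT format repository; per the constructor above, omitting "muvera" (or passing a dict) enables the fixed dimensional encoder, while passing "muvera": None disables it:

from txtai.models.pooling import PoolingFactory  # assumed export path, as above

pooling = PoolingFactory.create({
    "method": "latepooling",
    "path": "org/colbert-style-model",  # placeholder: PyLate or Stanford ColBERT format model id
    "device": "cpu",
    "modelargs": {"muvera": {"repetitions": 20, "hashes": 5, "projection": 16}},
})

# Query vs data categories apply different prefixes and max lengths in preencode
queries = pooling.encode(["what is late interaction?"], category="query")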
txtai/models/pooling/mean.py
@@ -0,0 +1,33 @@
+"""
+Mean module
+"""
+
+import torch
+
+from .base import Pooling
+
+
+class MeanPooling(Pooling):
+    """
+    Builds mean pooled vectors using outputs from a transformers model.
+    """
+
+    def forward(self, **inputs):
+        """
+        Runs mean pooling on token embeddings taking the input mask into account.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            mean pooled embeddings using output token embeddings (i.e. last hidden state)
+        """
+
+        # Run through transformers model
+        tokens = super().forward(**inputs)
+        mask = inputs["attention_mask"]
+
+        # Mean pooling
+        # pylint: disable=E1101
+        mask = mask.unsqueeze(-1).expand(tokens.size()).float()
+        return torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
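Note: the masked mean above can be checked on a toy tensor without loading a model. This sketch repeats the same computation on synthetic shapes:

import torch

# Toy last hidden state (batch=2, tokens=4, dims=3) and attention mask
tokens = torch.randn(2, 4, 3)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# Expand mask to token dimensions, then average only over unmasked tokens
mask = mask.unsqueeze(-1).expand(tokens.size()).float()
pooled = torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled.shape)  # torch.Size([2, 3])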
txtai/models/pooling/muvera.py
@@ -0,0 +1,164 @@
+"""
+Muvera module
+"""
+
+import numpy as np
+
+
+class Muvera:
+    """
+    Implements the MUVERA (Multi-Vector Retrieval via Fixed Dimensional Encodings) algorithm. This reduces late interaction multi-vector
+    outputs to a single fixed vector.
+
+    The size of the output vectors is set using the following parameters
+
+      output dimensions = repetitions * 2^hashes * projection
+
+    For example, the default parameters create vectors with the following output dimensions.
+
+      output dimensions = 20 * 2^5 * 16 = 10240
+
+    This code is based on the following:
+      - Paper: https://arxiv.org/abs/2405.19504
+      - GitHub: https://github.com/google/graph-mining/tree/main/sketching/point_cloud
+      - Python port of the original C++ code: https://github.com/sigridjineth/muvera-py
+    """
+
+    def __init__(self, repetitions=20, hashes=5, projection=16, seed=42):
+        """
+        Creates a Muvera instance.
+
+        Args:
+            repetitions: number of iterations
+            hashes: number of simhash partitions as 2^hashes
+            projection: dimensionality reduction, uses an identity projection when set to None
+            seed: random seed
+        """
+
+        # Number of repetitions
+        self.repetitions = repetitions
+
+        # Number of simhash projections
+        self.hashes = hashes
+
+        # Optional number of projected dimensions
+        self.projection = projection
+
+        # Seed
+        self.seed = seed
+
+    def __call__(self, data, category):
+        """
+        Transforms a list of multi-vector collections into single fixed vector outputs.
+
+        Args:
+            data: array of multi-vector vectors
+            category: embeddings category (query or data)
+        """
+
+        # Get stats
+        dimension, length = data[0].shape[1], len(data)
+
+        # Determine projection dimension
+        identity = not self.projection
+        projection = dimension if identity else self.projection
+
+        # Number of simhash partitions
+        partitions = 2**self.hashes
+
+        # Document tracking
+        lengths = np.array([len(doc) for doc in data], dtype=np.int32)
+        total = np.sum(lengths)
+        documents = np.repeat(np.arange(length), lengths)
+
+        # Stack all vectors
+        points = np.vstack(data).astype(np.float32)
+
+        # Output vectors
+        size = self.repetitions * partitions * projection
+        vectors = np.zeros((length, size), dtype=np.float32)
+
+        # Process each repetition
+        for number in range(self.repetitions):
+            seed = self.seed + number
+
+            # Calculate the simhash
+            sketches = points @ self.random(dimension, self.hashes, seed)
+
+            # Dimensionality reduction, if necessary
+            projected = points if identity else (points @ self.reducer(dimension, projection, seed))
+
+            # Get partition indices
+            bits = (sketches > 0).astype(np.uint32)
+            indices = np.zeros(total, dtype=np.uint32)
+
+            # Calculate vector indices
+            for x in range(self.hashes):
+                indices = (indices << 1) + (bits[:, x] ^ (indices & 1))
+
+            # Initialize storage
+            fdesum = np.zeros((length * partitions * projection,), dtype=np.float32)
+            counts = np.zeros((length, partitions), dtype=np.int32)
+
+            # Count vectors per partition per document
+            np.add.at(counts, (documents, indices), 1)
+
+            # Aggregate vectors using flattened indexing for efficiency
+            part = documents * partitions + indices
+            base = part * projection
+
+            for d in range(projection):
+                flat = base + d
+                np.add.at(fdesum, flat, projected[:, d])
+
+            # Reshape for easier manipulation
+            # pylint: disable=E1121
+            fdesum = fdesum.reshape(length, partitions, projection)
+
+            # Convert sums to averages for data category
+            if category == "data":
+                # Safe division (avoid divide by zero)
+                counts = counts[:, :, np.newaxis]
+                np.divide(fdesum, counts, out=fdesum, where=counts > 0)
+
+            # Save results
+            start = number * partitions * projection
+            vectors[:, start : start + partitions * projection] = fdesum.reshape(length, -1)
+
+        return vectors
+
+    def random(self, dimension, projection, seed):
+        """
+        Generates a random matrix for simhash projections.
+
+        Args:
+            dimension: number of dimensions for input vectors
+            projection: number of projection dimensions
+            seed: random seed
+
+        Returns:
+            random matrix for simhash projections
+        """
+
+        rng = np.random.default_rng(seed)
+        return rng.normal(loc=0.0, scale=1.0, size=(dimension, projection)).astype(np.float32)
+
+    def reducer(self, dimension, projection, seed):
+        """
+        Generates a random matrix for dimensionality reduction using the AMS sketch algorithm.
+
+        Args:
+            dimension: number of input dimensions
+            projection: number of dimensions to project inputs to
+
+        Returns:
+            dimensionality reduction matrix
+        """
+
+        rng = np.random.default_rng(seed)
+        out = np.zeros((dimension, projection), dtype=np.float32)
+        indices = rng.integers(0, projection, size=dimension)
+        signs = rng.choice([-1.0, 1.0], size=dimension)
+        out[np.arange(dimension), indices] = signs
+
+        return out
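Note: a small sketch of the output dimensions and a call with synthetic multi-vector data (the input array is a random placeholder; the import path follows txtai/models/pooling/__init__.py in this diff):

import numpy as np

from txtai.models.pooling import Muvera  # assumed export path

encoder = Muvera(repetitions=20, hashes=5, projection=16, seed=42)

# 3 documents, 10 token vectors each, 128 dimensions per token vector
data = np.random.rand(3, 10, 128).astype(np.float32)

# Fixed dimensional encodings: 20 * 2^5 * 16 = 10240 dimensions per document
vectors = encoder(data, "data")
print(vectors.shape)  # (3, 10240)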