mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
--- /dev/null
+++ b/txtai/models/pooling/base.py
@@ -0,0 +1,141 @@
+"""
+Pooling module
+"""
+
+import numpy as np
+import torch
+
+from torch import nn
+
+from ..models import Models
+
+
+class Pooling(nn.Module):
+    """
+    Builds pooled vectors using outputs from a transformers model.
+    """
+
+    def __init__(self, path, device, tokenizer=None, maxlength=None, modelargs=None):
+        """
+        Creates a new Pooling model.
+
+        Args:
+            path: path to model, accepts Hugging Face model hub id or local path
+            device: tensor device id
+            tokenizer: optional path to tokenizer
+            maxlength: max sequence length
+            modelargs: additional model arguments
+        """
+
+        super().__init__()
+
+        self.model = Models.load(path, modelargs=modelargs)
+        self.tokenizer = Models.tokenizer(tokenizer if tokenizer else path)
+        self.device = Models.device(device)
+
+        # Detect unbounded tokenizer typically found in older models
+        Models.checklength(self.model, self.tokenizer)
+
+        # Set max length
+        self.maxlength = maxlength if maxlength else self.tokenizer.model_max_length if self.tokenizer.model_max_length != int(1e30) else None
+
+        # Move to device
+        self.to(self.device)
+
+    def encode(self, documents, batch=32, category=None):
+        """
+        Builds an array of pooled embeddings for documents.
+
+        Args:
+            documents: list of documents used to build embeddings
+            batch: model batch size
+            category: embeddings category (query or data)
+
+        Returns:
+            pooled embeddings
+        """
+
+        # Split documents into batches and process
+        results = []
+
+        # Apply pre encoding transformation logic
+        documents = self.preencode(documents, category)
+
+        # Sort document indices from largest to smallest to enable efficient batching
+        # This performance tweak matches logic in sentence-transformers
+        lengths = np.argsort([-len(x) if x else 0 for x in documents])
+        documents = [documents[x] for x in lengths]
+
+        for chunk in self.chunk(documents, batch):
+            # Tokenize input
+            inputs = self.tokenizer(chunk, padding=True, truncation="longest_first", return_tensors="pt", max_length=self.maxlength)
+
+            # Move inputs to device
+            inputs = inputs.to(self.device)
+
+            # Run inputs through model
+            with torch.no_grad():
+                outputs = self.forward(**inputs)
+
+            # Add batch result
+            results.extend(outputs.cpu().to(torch.float32).numpy())
+
+        # Apply post encoding transformation logic
+        results = self.postencode(results, category)
+
+        # Restore original order and return array
+        return np.asarray([results[x] for x in np.argsort(lengths)])
+
+    def chunk(self, texts, size):
+        """
+        Splits texts into separate batch sizes specified by size.
+
+        Args:
+            texts: text elements
+            size: batch size
+
+        Returns:
+            list of evenly sized batches with the last batch having the remaining elements
+        """
+
+        return [texts[x : x + size] for x in range(0, len(texts), size)]
+
+    def forward(self, **inputs):
+        """
+        Runs inputs through transformers model and returns outputs.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            model outputs
+        """
+
+        return self.model(**inputs)[0]
+
+    # pylint: disable=W0613
+    def preencode(self, documents, category):
+        """
+        Applies pre encoding transformation logic.
+
+        Args:
+            documents: list of documents used to build embeddings
+            category: embeddings category (query or data)
+        """
+
+        return documents
+
+    # pylint: disable=W0613
+    def postencode(self, results, category):
+        """
+        Applies post encoding transformation logic.
+
+        Args:
+            results: list of results
+            category: embeddings category (query or data)
+
+        Returns:
+            results with transformation logic applied
+        """
+
+        return results
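The sort-then-restore trick in `encode` above is easy to miss: documents are reordered longest-first for efficient batching, then a second `argsort` maps results back to the caller's order. A minimal standalone sketch, using a stand-in "encoder" that just records input lengths instead of a real model:

```python
import numpy as np

# Documents are processed longest-first, then results are mapped back to
# the caller's original order with a second argsort.
documents = ["a short text", "a much longer text that pads less when batched", "mid length text"]

# Sort indices from longest to shortest
lengths = np.argsort([-len(x) if x else 0 for x in documents])
ordered = [documents[x] for x in lengths]

# Pretend "encoding" that just records the input length
results = [len(x) for x in ordered]

# Restore original order: results line up with the input documents again
restored = [results[x] for x in np.argsort(lengths)]
assert restored == [len(x) for x in documents]
```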
--- /dev/null
+++ b/txtai/models/pooling/cls.py
@@ -0,0 +1,28 @@
+"""
+CLS module
+"""
+
+from .base import Pooling
+
+
+class ClsPooling(Pooling):
+    """
+    Builds CLS pooled vectors using outputs from a transformers model.
+    """
+
+    def forward(self, **inputs):
+        """
+        Runs CLS pooling on token embeddings.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            CLS pooled embeddings using output token embeddings (i.e. last hidden state)
+        """
+
+        # Run through transformers model
+        tokens = super().forward(**inputs)
+
+        # CLS token pooling
+        return tokens[:, 0]
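CLS pooling reduces to a single indexing operation over the model's token embeddings. A quick sketch with a dummy tensor:

```python
import torch

# With a (batch, sequence, hidden) tensor of token embeddings,
# tokens[:, 0] selects the first token (the [CLS] position) per sequence.
tokens = torch.randn(2, 4, 8)  # batch=2, sequence=4, hidden=8

pooled = tokens[:, 0]
assert pooled.shape == (2, 8)
assert torch.equal(pooled[0], tokens[0, 0])
```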
--- /dev/null
+++ b/txtai/models/pooling/factory.py
@@ -0,0 +1,144 @@
+"""
+Factory module
+"""
+
+import json
+import os
+
+from huggingface_hub.errors import HFValidationError
+from transformers.utils import cached_file
+
+from .base import Pooling
+from .cls import ClsPooling
+from .late import LatePooling
+from .mean import MeanPooling
+
+
+class PoolingFactory:
+    """
+    Method to create pooling models.
+    """
+
+    @staticmethod
+    def create(config):
+        """
+        Create a Pooling model.
+
+        Args:
+            config: pooling configuration
+
+        Returns:
+            Pooling
+        """
+
+        # Unpack parameters
+        method, path, device, tokenizer, maxlength, modelargs = [
+            config.get(x) for x in ["method", "path", "device", "tokenizer", "maxlength", "modelargs"]
+        ]
+
+        # Derive maxlength, if applicable
+        maxlength = PoolingFactory.maxlength(path) if isinstance(maxlength, bool) and maxlength else maxlength
+
+        # Default pooling returns hidden state
+        if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)) or method == "pooling":
+            return Pooling(path, device, tokenizer, maxlength, modelargs)
+
+        # Derive pooling method if it's not specified and path is a string
+        if (not method or method not in ("clspooling", "meanpooling", "latepooling")) and isinstance(path, str):
+            method = PoolingFactory.method(path)
+
+        # Check for cls pooling
+        if method == "clspooling":
+            return ClsPooling(path, device, tokenizer, maxlength, modelargs)
+
+        # Check for late pooling
+        if method == "latepooling":
+            return LatePooling(path, device, tokenizer, maxlength, modelargs)
+
+        # Default to mean pooling
+        return MeanPooling(path, device, tokenizer, maxlength, modelargs)
+
+    @staticmethod
+    def method(path):
+        """
+        Determines the pooling method using the sentence transformers pooling config.
+
+        Args:
+            path: model path
+
+        Returns:
+            pooling method
+        """
+
+        # Default method
+        method = "meanpooling"
+
+        # Load 1_Pooling/config.json file
+        config = PoolingFactory.load(path, "1_Pooling/config.json")
+
+        # Set to CLS pooling if it's enabled and mean pooling is disabled
+        if config and config["pooling_mode_cls_token"] and not config["pooling_mode_mean_tokens"]:
+            method = "clspooling"
+
+        # Check for late interaction pooling
+        if not config:
+            # Load 1_Dense/config.json
+            config = PoolingFactory.load(path, "1_Dense/config.json")
+            if config:
+                method = "latepooling"
+
+            # Load config.json and check architecture
+            else:
+                config = PoolingFactory.load(path, "config.json")
+                if config and "HF_ColBERT" in config.get("architectures", []):
+                    method = "latepooling"
+
+        return method
+
+    @staticmethod
+    def maxlength(path):
+        """
+        Reads the max_seq_length parameter from sentence transformers config.
+
+        Args:
+            path: model path
+
+        Returns:
+            max sequence length
+        """
+
+        # Default length is unset
+        maxlength = None
+
+        # Read max_seq_length from sentence_bert_config.json
+        config = PoolingFactory.load(path, "sentence_bert_config.json")
+        maxlength = config.get("max_seq_length") if config else maxlength
+
+        return maxlength
+
+    @staticmethod
+    def load(path, name):
+        """
+        Loads a JSON config file from the Hugging Face Hub.
+
+        Args:
+            path: model path
+            name: file to load
+
+        Returns:
+            config
+        """
+
+        # Download file and parse JSON
+        config = None
+        try:
+            path = cached_file(path_or_repo_id=path, filename=name)
+            if path:
+                with open(path, encoding="utf-8") as f:
+                    config = json.load(f)
+
+        # Ignore this error - invalid repo or directory
+        except (HFValidationError, OSError):
+            pass
+
+        return config
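Putting the factory together, here is a usage sketch. The model id and device value are illustrative assumptions, not requirements of this diff; the import path follows the file layout listed above:

```python
from txtai.models.pooling.factory import PoolingFactory

# "method" is optional; when omitted, PoolingFactory.method() inspects
# 1_Pooling/config.json (and the late interaction configs) to pick the class.
model = PoolingFactory.create({
    "path": "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model id
    "device": -1,                                      # CPU in txtai's device id convention
    "maxlength": True,                                 # derive max_seq_length from sentence_bert_config.json
})

embeddings = model.encode(["txtai builds semantic search applications"])
print(embeddings.shape)
```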
--- /dev/null
+++ b/txtai/models/pooling/late.py
@@ -0,0 +1,173 @@
+"""
+Late module
+"""
+
+import json
+
+import numpy as np
+import torch
+
+from huggingface_hub.errors import HFValidationError
+from safetensors import safe_open
+from torch import nn
+from transformers.utils import cached_file
+
+from .base import Pooling
+from .muvera import Muvera
+
+
+class LatePooling(Pooling):
+    """
+    Builds late pooled vectors using outputs from a transformers model.
+    """
+
+    def __init__(self, path, device, tokenizer=None, maxlength=None, modelargs=None):
+        # Check if fixed dimensional encoder is enabled
+        modelargs = modelargs.copy() if modelargs else {}
+        muvera = modelargs.pop("muvera", None)
+        self.encoder = Muvera(**muvera) if muvera is not None else None
+
+        # Call parent initialization
+        super().__init__(path, device, tokenizer, maxlength, modelargs)
+
+        # Get linear weights path
+        config = self.load(path, "1_Dense/config.json")
+        if config:
+            # PyLate weights format
+            name = "1_Dense/model.safetensors"
+        else:
+            # Stanford weights format
+            name = "model.safetensors"
+
+        # Read model settings
+        self.qprefix, self.qlength, self.dprefix, self.dlength = self.settings(path, config)
+
+        # Load linear layer
+        path = cached_file(path_or_repo_id=path, filename=name)
+        with safe_open(filename=path, framework="pt") as f:
+            weights = f.get_tensor("linear.weight")
+
+        # Load weights into linear layer
+        self.linear = nn.Linear(weights.shape[1], weights.shape[0], bias=False, device=self.device, dtype=weights.dtype)
+        with torch.no_grad():
+            self.linear.weight.copy_(weights)
+
+    def forward(self, **inputs):
+        """
+        Runs late pooling on token embeddings.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            Late pooled embeddings using output token embeddings (i.e. last hidden state)
+        """
+
+        # Run through transformers model
+        tokens = super().forward(**inputs)
+
+        # Run through final linear layer and return
+        return self.linear(tokens)
+
+    def preencode(self, documents, category):
+        """
+        Apply prefixes and lengths to data.
+
+        Args:
+            documents: list of documents used to build embeddings
+            category: embeddings category (query or data)
+        """
+
+        results = []
+
+        # Apply prefix
+        for text in documents:
+            prefix = self.qprefix if category == "query" else self.dprefix
+            if prefix:
+                text = f"{prefix}{text}"
+
+            results.append(text)
+
+        # Set maxlength
+        maxlength = self.qlength if category == "query" else self.dlength
+        if maxlength:
+            self.maxlength = maxlength
+
+        return results
+
+    def postencode(self, results, category):
+        """
+        Normalizes and pads results.
+
+        Args:
+            results: input results
+
+        Returns:
+            normalized results with padding
+        """
+
+        length = 0
+        for vectors in results:
+            # Get max length
+            if vectors.shape[0] > length:
+                length = vectors.shape[0]
+
+            # Normalize vectors
+            vectors /= np.linalg.norm(vectors, axis=1)[:, np.newaxis]
+
+        # Pad values
+        data = []
+        for vectors in results:
+            data.append(np.pad(vectors, [(0, length - vectors.shape[0]), (0, 0)]))
+
+        # Build NumPy array
+        data = np.asarray(data)
+
+        # Apply fixed dimensional encoder, if necessary
+        return self.encoder(data, category) if self.encoder else data
+
+    def settings(self, path, config):
+        """
+        Reads model settings.
+
+        Args:
+            path: model path
+            config: PyLate model format if provided, otherwise read from Stanford format
+        """
+
+        if config:
+            # PyLate format
+            config = self.load(path, "config_sentence_transformers.json")
+            params = ["query_prefix", "query_length", "document_prefix", "document_length"]
+        else:
+            # Stanford format
+            config = self.load(path, "artifact.metadata")
+            params = ["query_token_id", "query_maxlen", "doc_token_id", "doc_maxlen"]
+
+        return [config.get(p) for p in params]
+
+    def load(self, path, name):
+        """
+        Loads a JSON config file from the Hugging Face Hub.
+
+        Args:
+            path: model path
+            name: file to load
+
+        Returns:
+            config
+        """
+
+        # Download file and parse JSON
+        config = None
+        try:
+            path = cached_file(path_or_repo_id=path, filename=name)
+            if path:
+                with open(path, encoding="utf-8") as f:
+                    config = json.load(f)
+
+        # Ignore this error - invalid repo or directory
+        except (HFValidationError, OSError):
+            pass
+
+        return config
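The `postencode` step above L2-normalizes each document's token vectors, then zero-pads every document to the longest token count so the batch stacks into one array. A minimal sketch on toy arrays:

```python
import numpy as np

# Two "documents" with different token counts, 4-dim token vectors
results = [np.random.rand(3, 4).astype(np.float32), np.random.rand(5, 4).astype(np.float32)]

# Normalize each token vector to unit length
length = max(vectors.shape[0] for vectors in results)
for vectors in results:
    vectors /= np.linalg.norm(vectors, axis=1)[:, np.newaxis]

# Zero-pad to the longest token count and stack into (documents, tokens, dimensions)
data = np.asarray([np.pad(vectors, [(0, length - vectors.shape[0]), (0, 0)]) for vectors in results])
print(data.shape)  # (2, 5, 4)
```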
--- /dev/null
+++ b/txtai/models/pooling/mean.py
@@ -0,0 +1,33 @@
+"""
+Mean module
+"""
+
+import torch
+
+from .base import Pooling
+
+
+class MeanPooling(Pooling):
+    """
+    Builds mean pooled vectors using outputs from a transformers model.
+    """
+
+    def forward(self, **inputs):
+        """
+        Runs mean pooling on token embeddings taking the input mask into account.
+
+        Args:
+            inputs: model inputs
+
+        Returns:
+            mean pooled embeddings using output token embeddings (i.e. last hidden state)
+        """
+
+        # Run through transformers model
+        tokens = super().forward(**inputs)
+        mask = inputs["attention_mask"]
+
+        # Mean pooling
+        # pylint: disable=E1101
+        mask = mask.unsqueeze(-1).expand(tokens.size()).float()
+        return torch.sum(tokens * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
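The masked mean can be verified against a hand computation: padding positions (mask == 0) are zeroed out and excluded from the average, so pad tokens never dilute the vector. A minimal sketch:

```python
import torch

tokens = torch.randn(2, 4, 8)                      # batch=2, sequence=4, hidden=8
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # second sequence has 2 pad tokens

# Same computation as MeanPooling.forward
expanded = mask.unsqueeze(-1).expand(tokens.size()).float()
pooled = torch.sum(tokens * expanded, 1) / torch.clamp(expanded.sum(1), min=1e-9)

# Matches averaging only the unmasked token vectors
assert torch.allclose(pooled[1], tokens[1, :2].mean(0), atol=1e-6)
```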
--- /dev/null
+++ b/txtai/models/pooling/muvera.py
@@ -0,0 +1,164 @@
+"""
+Muvera module
+"""
+
+import numpy as np
+
+
+class Muvera:
+    """
+    Implements the MUVERA (Multi-Vector Retrieval via Fixed Dimensional Encodings) algorithm. This reduces late interaction multi-vector
+    outputs to a single fixed vector.
+
+    The size of the output vectors is set using the following parameters
+
+        output dimensions = repetitions * 2^hashes * projection
+
+    For example, the default parameters create vectors with the following output dimensions.
+
+        output dimensions = 20 * 2^5 * 16 = 10240
+
+    This code is based on the following:
+      - Paper: https://arxiv.org/abs/2405.19504
+      - GitHub: https://github.com/google/graph-mining/tree/main/sketching/point_cloud
+      - Python port of the original C++ code: https://github.com/sigridjineth/muvera-py
+    """
+
+    def __init__(self, repetitions=20, hashes=5, projection=16, seed=42):
+        """
+        Creates a Muvera instance.
+
+        Args:
+            repetitions: number of iterations
+            hashes: number of simhash partitions as 2^hashes
+            projection: dimensionality reduction, uses an identity projection when set to None
+            seed: random seed
+        """
+
+        # Number of repetitions
+        self.repetitions = repetitions
+
+        # Number of simhash projections
+        self.hashes = hashes
+
+        # Optional number of projected dimensions
+        self.projection = projection
+
+        # Seed
+        self.seed = seed
+
+    def __call__(self, data, category):
+        """
+        Transforms a list of multi-vector collections into single fixed vector outputs.
+
+        Args:
+            data: list of multi-vector arrays
+            category: embeddings category (query or data)
+        """
+
+        # Get stats
+        dimension, length = data[0].shape[1], len(data)
+
+        # Determine projection dimension
+        identity = not self.projection
+        projection = dimension if identity else self.projection
+
+        # Number of simhash partitions
+        partitions = 2**self.hashes
+
+        # Document tracking
+        lengths = np.array([len(doc) for doc in data], dtype=np.int32)
+        total = np.sum(lengths)
+        documents = np.repeat(np.arange(length), lengths)
+
+        # Stack all vectors
+        points = np.vstack(data).astype(np.float32)
+
+        # Output vectors
+        size = self.repetitions * partitions * projection
+        vectors = np.zeros((length, size), dtype=np.float32)
+
+        # Process each repetition
+        for number in range(self.repetitions):
+            seed = self.seed + number
+
+            # Calculate the simhash
+            sketches = points @ self.random(dimension, self.hashes, seed)
+
+            # Dimensionality reduction, if necessary
+            projected = points if identity else (points @ self.reducer(dimension, projection, seed))
+
+            # Get partition indices
+            bits = (sketches > 0).astype(np.uint32)
+            indices = np.zeros(total, dtype=np.uint32)
+
+            # Calculate vector indices
+            for x in range(self.hashes):
+                indices = (indices << 1) + (bits[:, x] ^ (indices & 1))
+
+            # Initialize storage
+            fdesum = np.zeros((length * partitions * projection,), dtype=np.float32)
+            counts = np.zeros((length, partitions), dtype=np.int32)
+
+            # Count vectors per partition per document
+            np.add.at(counts, (documents, indices), 1)
+
+            # Aggregate vectors using flattened indexing for efficiency
+            part = documents * partitions + indices
+            base = part * projection
+
+            for d in range(projection):
+                flat = base + d
+                np.add.at(fdesum, flat, projected[:, d])
+
+            # Reshape for easier manipulation
+            # pylint: disable=E1121
+            fdesum = fdesum.reshape(length, partitions, projection)
+
+            # Convert sums to averages for data category
+            if category == "data":
+                # Safe division (avoid divide by zero)
+                counts = counts[:, :, np.newaxis]
+                np.divide(fdesum, counts, out=fdesum, where=counts > 0)
+
+            # Save results
+            start = number * partitions * projection
+            vectors[:, start : start + partitions * projection] = fdesum.reshape(length, -1)
+
+        return vectors
+
+    def random(self, dimension, projection, seed):
+        """
+        Generates a random matrix for simhash projections.
+
+        Args:
+            dimension: number of dimensions for input vectors
+            projection: number of projection dimensions
+            seed: random seed
+
+        Returns:
+            random matrix for simhash projections
+        """
+
+        rng = np.random.default_rng(seed)
+        return rng.normal(loc=0.0, scale=1.0, size=(dimension, projection)).astype(np.float32)
+
+    def reducer(self, dimension, projection, seed):
+        """
+        Generates a random matrix for dimensionality reduction using the AMS sketch algorithm.
+
+        Args:
+            dimension: number of input dimensions
+            projection: number of dimensions to project inputs to
+
+        Returns:
+            dimensionality reduction matrix
+        """
+
+        rng = np.random.default_rng(seed)
+        out = np.zeros((dimension, projection), dtype=np.float32)
+        indices = rng.integers(0, projection, size=dimension)
+        signs = rng.choice([-1.0, 1.0], size=dimension)
+        out[np.arange(dimension), indices] = signs
+
+        return out
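A usage sketch of the fixed dimensional encoder, with random stand-in data; the import path is assumed from the file layout listed above:

```python
import numpy as np

from txtai.models.pooling.muvera import Muvera

# Default parameters: 20 repetitions * 2^5 partitions * 16 projected dims = 10240
encoder = Muvera(repetitions=20, hashes=5, projection=16)

# Two "documents" with different token counts, 128-dim token vectors
data = [np.random.rand(7, 128).astype(np.float32), np.random.rand(12, 128).astype(np.float32)]

# Each multi-vector collection collapses to one fixed-length vector
vectors = encoder(data, category="data")
print(vectors.shape)  # (2, 10240)
```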