mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,121 @@
|
|
1
|
+
"""
|
2
|
+
Factory module
|
3
|
+
"""
|
4
|
+
|
5
|
+
from ...util import Resolver
|
6
|
+
|
7
|
+
from .external import External
|
8
|
+
from .huggingface import HFVectors
|
9
|
+
from .litellm import LiteLLM
|
10
|
+
from .llama import LlamaCpp
|
11
|
+
from .m2v import Model2Vec
|
12
|
+
from .sbert import STVectors
|
13
|
+
from .words import WordVectors
|
14
|
+
|
15
|
+
|
16
|
+
class VectorsFactory:
    """
    Methods to create dense vector models.
    """

    @staticmethod
    def create(config, scoring=None, models=None):
        """
        Create a Vectors model instance.

        Args:
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Returns:
            Vectors instance or None when no vector method can be derived from config
        """

        # Determine vector method. Nothing to create when a method can't be derived.
        method = VectorsFactory.method(config)
        if not method:
            return None

        # External vectors
        if method == "external":
            return External(config, scoring, models)

        # LiteLLM vectors
        if method == "litellm":
            return LiteLLM(config, scoring, models)

        # llama.cpp vectors
        if method == "llama.cpp":
            return LlamaCpp(config, scoring, models)

        # Model2vec vectors
        if method == "model2vec":
            return Model2Vec(config, scoring, models)

        # Sentence Transformers vectors - only created when a model path is set
        if method == "sentence-transformers":
            return STVectors(config, scoring, models) if config and config.get("path") else None

        # Word vectors
        if method == "words":
            return WordVectors(config, scoring, models)

        # Transformers vectors - only created when a model path is set
        if HFVectors.ismethod(method):
            return HFVectors(config, scoring, models) if config and config.get("path") else None

        # Resolve custom method
        return VectorsFactory.resolve(method, config, scoring, models)

    @staticmethod
    def method(config):
        """
        Get or derive the vector method.

        Args:
            config: vector configuration

        Returns:
            vector method or None when it can't be derived
        """

        # A missing or empty configuration has no method
        if not config:
            return None

        # Determine vector method
        method = config.get("method")
        path = config.get("path")

        # Infer method from path, if blank
        if not method:
            if path:
                if LiteLLM.ismodel(path):
                    method = "litellm"
                elif LlamaCpp.ismodel(path):
                    method = "llama.cpp"
                elif Model2Vec.ismodel(path):
                    method = "model2vec"
                elif WordVectors.ismodel(path):
                    method = "words"
                else:
                    method = "transformers"
            elif config.get("transform"):
                # No path but a transform function is set
                method = "external"

        return method

    @staticmethod
    def resolve(backend, config, scoring, models):
        """
        Attempt to resolve a custom backend.

        Args:
            backend: backend class
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Returns:
            Vectors

        Raises:
            ImportError: if the backend can't be resolved or created
        """

        try:
            return Resolver()(backend)(config, scoring, models)
        except Exception as e:
            raise ImportError(f"Unable to resolve vectors backend: '{backend}'") from e
|
@@ -0,0 +1,44 @@
|
|
1
|
+
"""
|
2
|
+
Hugging Face module
|
3
|
+
"""
|
4
|
+
|
5
|
+
from ...models import Models, PoolingFactory
|
6
|
+
|
7
|
+
from ..base import Vectors
|
8
|
+
|
9
|
+
|
10
|
+
class HFVectors(Vectors):
    """
    Vectors implementation backed by the Hugging Face transformers library.
    """

    @staticmethod
    def ismethod(method):
        """
        Tests whether a method name maps to a local transformers-based model.

        Args:
            method: input method

        Returns:
            True if this is a local transformers-based model, False otherwise
        """

        # Method names handled by this class
        supported = ("transformers", "pooling", "clspooling", "meanpooling")
        return method in supported

    def loadmodel(self, path):
        # Assemble pooling configuration from the vector configuration
        config = self.config
        pooling = {
            "method": config.get("method"),
            "path": path,
            "device": Models.deviceid(config.get("gpu", True)),
            "tokenizer": config.get("tokenizer"),
            "maxlength": config.get("maxlength"),
            "modelargs": config.get("vectors", {}),
        }

        # Create the transformers pooling model
        return PoolingFactory.create(pooling)

    def encode(self, data, category=None):
        # Delegate encoding to the pooling model
        model = self.model
        return model.encode(data, batch=self.encodebatch, category=category)
|
@@ -0,0 +1,86 @@
|
|
1
|
+
"""
|
2
|
+
LiteLLM module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from transformers.utils import cached_file
|
8
|
+
|
9
|
+
# Conditional import
|
10
|
+
try:
|
11
|
+
import litellm as api
|
12
|
+
|
13
|
+
LITELLM = True
|
14
|
+
except ImportError:
|
15
|
+
LITELLM = False
|
16
|
+
|
17
|
+
from ..base import Vectors
|
18
|
+
|
19
|
+
|
20
|
+
class LiteLLM(Vectors):
    """
    Builds vectors using an external embeddings API via LiteLLM.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a LiteLLM model.

        Args:
            path: input path

        Returns:
            True if this is a LiteLLM model, False otherwise
        """

        # Only string paths are checked and only when LiteLLM is installed
        if not (isinstance(path, str) and LITELLM):
            return False

        # Any lookup error is treated as "not a LiteLLM model". Exception (instead of a bare except)
        # lets KeyboardInterrupt and SystemExit propagate.
        # pylint: disable=W0718
        debug = api.suppress_debug_info
        try:
            # Suppress debug messages for this test
            api.suppress_debug_info = True
            return bool(api.get_llm_provider(path)) and not LiteLLM.ishub(path)
        except Exception:
            return False
        finally:
            # Restore debug info value to original value
            api.suppress_debug_info = debug

    @staticmethod
    def ishub(path):
        """
        Checks if path is available on the HF Hub.

        Args:
            path: input path

        Returns:
            True if this is a model on the HF Hub
        """

        # Any lookup error is treated as "not on the HF Hub". Exception (instead of a bare except)
        # lets KeyboardInterrupt and SystemExit propagate.
        # pylint: disable=W0718
        try:
            # Paths without a "/" are not looked up
            if "/" not in path:
                return False

            return cached_file(path_or_repo_id=path, filename="config.json") is not None
        except Exception:
            return False

    def __init__(self, config, scoring, models):
        """
        Creates a new LiteLLM vectors instance.

        Args:
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Raises:
            ImportError: if LiteLLM is not installed
        """

        # Check before parent constructor since it calls loadmodel
        if not LITELLM:
            raise ImportError('LiteLLM is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # No local model to load, embeddings are generated through the LiteLLM API in encode
        return None

    def encode(self, data, category=None):
        # Call external embeddings API using LiteLLM
        # Batching is handled server-side
        response = api.embedding(model=self.config.get("path"), input=data, **self.config.get("vectors", {}))

        # Read response into a NumPy array
        return np.array([x["embedding"] for x in response.data], dtype=np.float32)
|
@@ -0,0 +1,84 @@
|
|
1
|
+
"""
|
2
|
+
Llama module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import os
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from huggingface_hub import hf_hub_download
|
10
|
+
|
11
|
+
# Conditional import
|
12
|
+
try:
|
13
|
+
from llama_cpp import Llama
|
14
|
+
|
15
|
+
LLAMA_CPP = True
|
16
|
+
except ImportError:
|
17
|
+
LLAMA_CPP = False
|
18
|
+
|
19
|
+
from ..base import Vectors
|
20
|
+
|
21
|
+
|
22
|
+
class LlamaCpp(Vectors):
    """
    Builds vectors using llama.cpp.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a llama.cpp model.

        Args:
            path: input path

        Returns:
            True if this is a llama.cpp model, False otherwise
        """

        # llama.cpp models are identified by the .gguf file extension
        return isinstance(path, str) and path.lower().endswith(".gguf")

    def __init__(self, config, scoring, models):
        """
        Creates a new llama.cpp vectors instance.

        Args:
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Raises:
            ImportError: if llama.cpp is not installed
        """

        # Check before parent constructor since it calls loadmodel
        if not LLAMA_CPP:
            raise ImportError('llama.cpp is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        """
        Loads a llama.cpp model in embedding mode.

        Args:
            path: local model path or Hugging Face Hub path

        Returns:
            Llama instance
        """

        # Check if this is a local path, otherwise download from the HF Hub
        path = path if os.path.exists(path) else self.download(path)

        # Additional model arguments. Work on a copy so the defaults set below and the
        # verbose pop do not modify the caller's configuration.
        modelargs = dict(self.config.get("vectors") or {})

        # Default GPU layers if not already set. GPU is enabled by default unless LLAMA_NO_METAL=1.
        if "n_gpu_layers" not in modelargs:
            gpu = self.config.get("gpu", os.environ.get("LLAMA_NO_METAL") != "1")
            modelargs["n_gpu_layers"] = -1 if gpu else 0

        # Verbose is passed explicitly, remove from additional arguments
        verbose = modelargs.pop("verbose", False)

        # Create llama.cpp instance
        return Llama(path, n_ctx=0, verbose=verbose, embedding=True, **modelargs)

    def encode(self, data, category=None):
        # Generate embeddings and return as a NumPy array
        # llama-cpp-python has its own batching built-in using n_batch parameter
        return np.array(self.model.embed(data), dtype=np.float32)

    def download(self, path):
        """
        Downloads path from the Hugging Face Hub.

        Args:
            path: full model path

        Returns:
            local cached model path
        """

        # Split into parts
        parts = path.split("/")

        # Calculate repo id split - first two parts when path has more than two parts, otherwise the first part
        repo = 2 if len(parts) > 2 else 1

        # Download and cache file
        return hf_hub_download(repo_id="/".join(parts[:repo]), filename="/".join(parts[repo:]))
|
@@ -0,0 +1,67 @@
|
|
1
|
+
"""
|
2
|
+
Model2Vec module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import json
|
6
|
+
|
7
|
+
from huggingface_hub.errors import HFValidationError
|
8
|
+
from transformers.utils import cached_file
|
9
|
+
|
10
|
+
# Conditional import
|
11
|
+
try:
|
12
|
+
from model2vec import StaticModel
|
13
|
+
|
14
|
+
MODEL2VEC = True
|
15
|
+
except ImportError:
|
16
|
+
MODEL2VEC = False
|
17
|
+
|
18
|
+
from ..base import Vectors
|
19
|
+
|
20
|
+
|
21
|
+
class Model2Vec(Vectors):
    """
    Vectors implementation backed by Model2Vec static models.
    """

    @staticmethod
    def ismodel(path):
        """
        Tests whether path points to a Model2Vec model by reading the model_type from its config.json.

        Args:
            path: input path

        Returns:
            True if this is a Model2Vec model, False otherwise
        """

        try:
            # Resolve config.json for this path
            resolved = cached_file(path_or_repo_id=path, filename="config.json")
            if resolved:
                # Parse JSON and check the declared model type
                with open(resolved, encoding="utf-8") as handle:
                    settings = json.load(handle)
                    return settings.get("model_type") == "model2vec"

        except (HFValidationError, OSError):
            # Invalid repo or directory - not a Model2Vec model
            return False

        return False

    def __init__(self, config, scoring, models):
        # Library check runs first since the parent constructor calls loadmodel
        if not MODEL2VEC:
            raise ImportError('Model2Vec is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Load static model from path
        return StaticModel.from_pretrained(path)

    def encode(self, data, category=None):
        # Additional arguments are read from the "vectors" configuration
        extra = self.config.get("vectors", {})

        # Encode data in batches
        model = self.model
        return model.encode(data, batch_size=self.encodebatch, **extra)
|
@@ -0,0 +1,92 @@
|
|
1
|
+
"""
|
2
|
+
Sentence Transformers module
|
3
|
+
"""
|
4
|
+
|
5
|
+
# Conditional import
|
6
|
+
try:
|
7
|
+
from sentence_transformers import SentenceTransformer
|
8
|
+
|
9
|
+
SENTENCE_TRANSFORMERS = True
|
10
|
+
except ImportError:
|
11
|
+
SENTENCE_TRANSFORMERS = False
|
12
|
+
|
13
|
+
from ...models import Models
|
14
|
+
|
15
|
+
from ..base import Vectors
|
16
|
+
|
17
|
+
|
18
|
+
class STVectors(Vectors):
    """
    Vectors implementation backed by sentence-transformers (aka SBERT).
    """

    def __init__(self, config, scoring, models):
        # Library check runs first since the parent constructor calls loadmodel
        if not SENTENCE_TRANSFORMERS:
            raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

        # Pool must exist before the parent constructor runs, loadmodel may set it
        self.pool = None

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Requested device setting
        gpu = self.config.get("gpu", True)
        multiprocess = False

        # A single GPU is the default. "all" requests a process per GPU.
        if isinstance(gpu, str) and gpu == "all":
            # Number of accelerator devices available
            count = Models.acceleratorcount()

            # Multiprocessing pool is only used with more than one device
            multiprocess = count > 1
            gpu = not multiprocess

        # Resolve the tensor device
        device = Models.device(Models.deviceid(gpu))

        # Additional model arguments
        extra = self.config.get("vectors", {})

        # Load sentence-transformers encoder
        encoder = self.loadencoder(path, device=device, **extra)

        # Start process pool for multiple GPUs
        if multiprocess:
            self.pool = encoder.start_multi_process_pool()

        return encoder

    def encode(self, data, category=None):
        # Select the encode method for the input category
        if category == "query":
            method = self.model.encode_query
        elif category == "data":
            method = self.model.encode_document
        else:
            method = self.model.encode

        # Additional encoding arguments
        extra = self.config.get("encodeargs", {})

        # Encode with sentence transformers encoder
        return method(data, pool=self.pool, batch_size=self.encodebatch, **extra)

    def close(self):
        # Pool is stopped first, the parent method closes the model
        if self.pool:
            self.model.stop_multi_process_pool(self.pool)
            self.pool = None

        super().close()

    def loadencoder(self, path, device, **kwargs):
        """
        Creates the sentence-transformers encoder for path.

        Args:
            path: model path
            device: tensor device
            kwargs: additional keyword args

        Returns:
            embeddings encoder
        """

        encoder = SentenceTransformer(path, device=device, **kwargs)
        return encoder
|
@@ -0,0 +1,211 @@
|
|
1
|
+
"""
|
2
|
+
Word Vectors module
|
3
|
+
"""
|
4
|
+
|
5
|
+
import json
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
import tempfile
|
9
|
+
|
10
|
+
from multiprocessing import Pool
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
|
14
|
+
from huggingface_hub.errors import HFValidationError
|
15
|
+
from transformers.utils import cached_file
|
16
|
+
|
17
|
+
# Conditional import
|
18
|
+
try:
|
19
|
+
from staticvectors import Database, StaticVectors
|
20
|
+
|
21
|
+
STATICVECTORS = True
|
22
|
+
except ImportError:
|
23
|
+
STATICVECTORS = False
|
24
|
+
|
25
|
+
from ...pipeline import Tokenizer
|
26
|
+
|
27
|
+
from ..base import Vectors
|
28
|
+
|
29
|
+
# Logging configuration
|
30
|
+
logger = logging.getLogger(__name__)
|
31
|
+
|
32
|
+
# Multiprocessing helper methods
|
33
|
+
# pylint: disable=W0603
|
34
|
+
PARAMETERS, VECTORS = None, None
|
35
|
+
|
36
|
+
|
37
|
+
def create(config, scoring):
    """
    Multiprocessing helper method. Stores the parameters needed to build a vectors model in this process.
    The model itself is created on first use.

    Args:
        config: vector configuration
        scoring: scoring instance
    """

    global PARAMETERS, VECTORS

    # Constructor arguments for lazy loading - models cache is not set
    PARAMETERS = (config, scoring, None)

    # Clear any existing model
    VECTORS = None
|
51
|
+
|
52
|
+
|
53
|
+
def transform(document):
    """
    Multiprocessing helper method. Builds an embeddings vector for a document.

    Args:
        document: (id, data, tags)

    Returns:
        (id, embedding)
    """

    global VECTORS

    # Create the vectors model on first use from the stored parameters
    if not VECTORS:
        VECTORS = WordVectors(*PARAMETERS)

    uid = document[0]
    embedding = VECTORS.transform(document)
    return (uid, embedding)
|
70
|
+
|
71
|
+
|
72
|
+
class WordVectors(Vectors):
    """
    Builds vectors using weighted word embeddings.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a WordVectors model.

        Args:
            path: input path

        Returns:
            True if this is a WordVectors model, False otherwise
        """

        # Check if this is a SQLite database
        if WordVectors.isdatabase(path):
            return True

        try:
            # Download file and parse JSON
            path = cached_file(path_or_repo_id=path, filename="config.json")
            if path:
                with open(path, encoding="utf-8") as f:
                    config = json.load(f)
                    return config.get("model_type") == "staticvectors"

        # Ignore this error - invalid repo or directory
        except (HFValidationError, OSError):
            pass

        return False

    @staticmethod
    def isdatabase(path):
        """
        Checks if this is a SQLite database file which is the file format used for word vectors databases.

        Args:
            path: path to check

        Returns:
            True if this is a SQLite database
        """

        return isinstance(path, str) and STATICVECTORS and Database.isdatabase(path)

    def __init__(self, config, scoring, models):
        """
        Creates a new WordVectors instance.

        Args:
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Raises:
            ImportError: if staticvectors is not installed
        """

        # Check before parent constructor since it calls loadmodel
        if not STATICVECTORS:
            raise ImportError('staticvectors is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        # Load static vectors model from path
        return StaticVectors(path)

    def encode(self, data, category=None):
        # Iterate over each data element, tokenize (if necessary) and build an aggregated embeddings vector
        embeddings = []
        for tokens in data:
            # Convert to tokens, if necessary. If tokenized list is empty, use input string.
            if isinstance(tokens, str):
                tokenlist = Tokenizer.tokenize(tokens)
                tokens = tokenlist if tokenlist else [tokens]

            # Generate weights for each vector using a scoring method
            weights = self.scoring.weights(tokens) if self.scoring else None

            # pylint: disable=E1133
            if weights and any(x > 0 for x in weights):
                # Build weighted average embeddings vector. Create weights array as float32 to match embeddings precision.
                embedding = np.average(self.lookup(tokens), weights=np.array(weights, dtype=np.float32), axis=0)
            else:
                # If no weights (or no positive weights), use mean
                embedding = np.mean(self.lookup(tokens), axis=0)

            embeddings.append(embedding)

        return np.array(embeddings, dtype=np.float32)

    def index(self, documents, batchsize=500, checkpoint=None):
        """
        Builds embeddings for documents. When the "parallel" configuration is enabled (default), vectors are
        built with a multiprocessing pool and streamed to a temporary file. Otherwise, this delegates to the
        parent index method.

        Args:
            documents: iterable of (id, data, tags)
            batchsize: number of embeddings written per batch
            checkpoint: optional checkpoint, forwarded to the parent method in single process mode.
                        It is not used by the multiprocessing path.

        Returns:
            (ids, dimensions, batches, stream) - stream is the path to the temporary embeddings file
        """

        # Derive number of parallel processes
        parallel = self.config.get("parallel", True)
        parallel = os.cpu_count() if parallel and isinstance(parallel, bool) else int(parallel)

        # Use default single process indexing logic, checkpoint is forwarded so it's not silently dropped
        if not parallel:
            return super().index(documents, batchsize, checkpoint)

        # Customize indexing logic with multiprocessing pool to efficiently build vectors
        ids, dimensions, batches, stream = [], None, 0, None

        # Shared objects with Pool
        args = (self.config, self.scoring)

        # Convert all documents to embedding arrays, stream embeddings to disk to control memory usage
        with Pool(parallel, initializer=create, initargs=args) as pool:
            # File is kept after close (delete=False), the path is returned to the caller
            with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False) as output:
                stream = output.name
                embeddings = []
                for uid, embedding in pool.imap(transform, documents, self.encodebatch):
                    if not dimensions:
                        # Set number of dimensions for embeddings
                        dimensions = embedding.shape[0]

                    ids.append(uid)
                    embeddings.append(embedding)

                    # Write a full batch and start a new one
                    if len(embeddings) == batchsize:
                        np.save(output, np.array(embeddings, dtype=np.float32), allow_pickle=False)
                        batches += 1

                        embeddings = []

                # Final embeddings batch
                if embeddings:
                    np.save(output, np.array(embeddings, dtype=np.float32), allow_pickle=False)
                    batches += 1

        return (ids, dimensions, batches, stream)

    def lookup(self, tokens):
        """
        Queries word vectors for given list of input tokens.

        Args:
            tokens: list of tokens to query

        Returns:
            word vectors array
        """

        return self.model.embeddings(tokens)

    def tokens(self, data):
        # Skip tokenization rules, data is returned unchanged
        return data
|