mseep-txtai 9.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/archive/base.py
ADDED
@@ -0,0 +1,104 @@
"""
Archive module
"""

import os

from tempfile import TemporaryDirectory

from .tar import Tar
from .zip import Zip


class Archive:
    """
    Base class for archive instances.
    """

    def __init__(self, directory=None):
        """
        Creates a new archive instance.

        Args:
            directory: directory to use as working directory, defaults to a temporary directory
        """

        self.directory = directory

    def isarchive(self, path):
        """
        Checks if path is an archive file based on the extension.

        Args:
            path: path to check

        Returns:
            True if the path ends with an archive extension, False otherwise
        """

        return path and any(path.lower().endswith(extension) for extension in [".tar.bz2", ".tar.gz", ".tar.xz", ".zip"])

    def path(self):
        """
        Gets the current working directory for this archive instance.

        Returns:
            archive working directory
        """

        # Default to a temporary directory. All files created in this directory will be deleted
        # when this archive instance goes out of scope.
        if not self.directory:
            # pylint: disable=R1732
            self.directory = TemporaryDirectory()

        return self.directory.name if isinstance(self.directory, TemporaryDirectory) else self.directory

    def load(self, path, compression=None):
        """
        Extracts file at path to archive working directory.

        Args:
            path: path to archive file
            compression: compression format, infers from path if not provided
        """

        # Unpack compressed file
        compress = self.create(path, compression)
        compress.unpack(path, self.path())

    def save(self, path, compression=None):
        """
        Archives files in archive working directory to file at path.

        Args:
            path: path to archive file
            compression: compression format, infers from path if not provided
        """

        # Create output directory, if necessary
        output = os.path.dirname(path)
        if output:
            os.makedirs(output, exist_ok=True)

        # Pack into compressed file
        compress = self.create(path, compression)
        compress.pack(self.path(), path)

    def create(self, path, compression):
        """
        Method to construct a Compress instance.

        Args:
            path: file path
            compression: compression format, infers using file extension if not provided

        Returns:
            Compress
        """

        # Infer compression format from path if not provided
        compression = compression if compression else path.lower().split(".")[-1]

        # Create compression instance
        return Zip() if compression == "zip" else Tar()
txtai/archive/compress.py
ADDED
@@ -0,0 +1,51 @@
"""
Compress module
"""

import os


class Compress:
    """
    Base class for Compress instances.
    """

    def pack(self, path, output):
        """
        Compresses files in directory path to file output.

        Args:
            path: input directory path
            output: output file
        """

        raise NotImplementedError

    def unpack(self, path, output):
        """
        Extracts all files in path to output.

        Args:
            path: input file path
            output: output directory
        """

        raise NotImplementedError

    def validate(self, directory, path):
        """
        Validates path is under directory.

        Args:
            directory: base directory
            path: path to validate

        Returns:
            True if path is under directory, False otherwise
        """

        directory = os.path.abspath(directory)
        path = os.path.abspath(path)
        prefix = os.path.commonprefix([directory, path])

        return prefix == directory
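The validate method above is the path check run by Tar.unpack and Zip.unpack before extraction. A minimal sketch of its behavior, using made-up paths:

# Illustrative only: demonstrates Compress.validate semantics with hypothetical paths.
from txtai.archive.compress import Compress

compress = Compress()

# A path that resolves under the base directory is accepted
print(compress.validate("/tmp/extract", "/tmp/extract/data/embeddings"))  # True

# A path that escapes the base directory via ".." is rejected,
# which is why unpack() raises IOError for such archive entries
print(compress.validate("/tmp/extract", "/tmp/extract/../etc/passwd"))  # False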
txtai/archive/factory.py
ADDED
@@ -0,0 +1,25 @@
"""
Factory module
"""

from .base import Archive


class ArchiveFactory:
    """
    Methods to create Archive instances.
    """

    @staticmethod
    def create(directory=None):
        """
        Create a new Archive instance.

        Args:
            directory: optional default working directory, otherwise uses a temporary directory

        Returns:
            Archive
        """

        return Archive(directory)
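Taken together, the archive classes support a simple pack/unpack workflow. A minimal usage sketch, assuming write access to the paths shown (all file names are placeholders):

# Usage sketch for Archive/ArchiveFactory; paths and file names are placeholders.
import os

from txtai.archive.factory import ArchiveFactory

# Stage content in the archive's working directory (a TemporaryDirectory by default)
archive = ArchiveFactory.create()
with open(os.path.join(archive.path(), "config.json"), "w", encoding="utf-8") as f:
    f.write("{}")

# Pack the working directory; format is inferred from the extension (.zip -> Zip, else Tar)
archive.save("/tmp/index.tar.gz")

# A second instance unpacks the file into its own working directory
reader = ArchiveFactory.create()
reader.load("/tmp/index.tar.gz")
print(os.listdir(reader.path()))  # ['config.json']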
txtai/archive/tar.py
ADDED
@@ -0,0 +1,49 @@
"""
Tar module
"""

import os
import tarfile

from .compress import Compress


class Tar(Compress):
    """
    Tar compression
    """

    def pack(self, path, output):
        # Infer compression type
        compression = self.compression(output)

        with tarfile.open(output, f"w:{compression}" if compression else "w") as tar:
            tar.add(path, arcname=".")

    def unpack(self, path, output):
        # Infer compression type
        compression = self.compression(path)

        with tarfile.open(path, f"r:{compression}" if compression else "r") as tar:
            # Validate paths
            for member in tar.getmembers():
                fullpath = os.path.join(path, member.name)
                if not self.validate(path, fullpath):
                    raise IOError(f"Invalid tar entry: {member.name}")

            tar.extractall(output)

    def compression(self, path):
        """
        Gets compression type for path.

        Args:
            path: path to file

        Returns:
            compression type
        """

        # Infer compression type from last path component. Limit to supported types.
        compression = path.lower().split(".")[-1]
        return compression if compression in ("bz2", "gz", "xz") else None
txtai/archive/zip.py
ADDED
@@ -0,0 +1,35 @@
"""
Zip module
"""

import os

from zipfile import ZipFile, ZIP_DEFLATED

from .compress import Compress


class Zip(Compress):
    """
    Zip compression
    """

    def pack(self, path, output):
        with ZipFile(output, "w", ZIP_DEFLATED) as zfile:
            for root, _, files in sorted(os.walk(path)):
                for f in files:
                    # Generate archive name with relative path, if necessary
                    name = os.path.join(os.path.relpath(root, path), f)

                    # Write file to zip
                    zfile.write(os.path.join(root, f), arcname=name)

    def unpack(self, path, output):
        with ZipFile(path, "r") as zfile:
            # Validate path if directory specified
            for fullpath in zfile.namelist():
                fullpath = os.path.join(path, fullpath)
                if os.path.dirname(fullpath) and not self.validate(path, fullpath):
                    raise IOError(f"Invalid zip entry: {fullpath}")

            zfile.extractall(output)
txtai/cloud/__init__.py
ADDED
txtai/cloud/base.py
ADDED
@@ -0,0 +1,106 @@
"""
Cloud module
"""

import os

from ..archive import ArchiveFactory


class Cloud:
    """
    Base class for cloud providers. Cloud providers sync content between local and remote storage.
    """

    def __init__(self, config):
        """
        Creates a new cloud connection.

        Args:
            config: cloud configuration
        """

        self.config = config

    def exists(self, path=None):
        """
        Checks if path exists in cloud. If path is None, this method checks if the container exists.

        Args:
            path: path to check

        Returns:
            True if path or container exists, False otherwise
        """

        return self.metadata(path) is not None

    def metadata(self, path=None):
        """
        Returns metadata for path from cloud. If path is None, this method returns metadata
        for container.

        Args:
            path: retrieve metadata for this path

        Returns:
            path or container metadata if available, otherwise returns None
        """

        raise NotImplementedError

    def load(self, path=None):
        """
        Retrieves content from cloud and stores locally. If path is empty, this method retrieves
        all content in the container.

        Args:
            path: path to retrieve

        Returns:
            local path which can be different than input path
        """

        raise NotImplementedError

    def save(self, path):
        """
        Sends local content stored in path to cloud.

        Args:
            path: local path to sync
        """

        raise NotImplementedError

    def isarchive(self, path):
        """
        Check if path is an archive file.

        Args:
            path: path to check

        Returns:
            True if path ends with an archive extension, False otherwise
        """

        return ArchiveFactory.create().isarchive(path)

    def listfiles(self, path):
        """
        Lists files in path. If path is a file, this method returns a single element list
        containing path.

        Args:
            path: path to list

        Returns:
            List of files
        """

        # List all files if path is a directory
        if os.path.isdir(path):
            return [os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]

        # Path is a file
        return [path]
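metadata, load and save are the methods a concrete provider implements; exists, isarchive and listfiles come from this base class. As a rough sketch of that contract, here is a hypothetical provider that mirrors archives to a local directory (not part of the package):

# Hypothetical example provider, for illustration of the Cloud contract only.
import os
import shutil

from txtai.cloud.base import Cloud


class LocalMirror(Cloud):
    """
    Toy provider that treats a local directory as the "container".
    """

    def metadata(self, path=None):
        # Map path to a file in the container directory, or to the directory itself
        target = os.path.join(self.config["container"], os.path.basename(path)) if path else self.config["container"]
        return {"path": target} if os.path.exists(target) else None

    def load(self, path=None):
        # Copy the mirrored file to the requested local path and return it
        shutil.copy(os.path.join(self.config["container"], os.path.basename(path)), path)
        return path

    def save(self, path):
        # Copy each local file into the container directory
        os.makedirs(self.config["container"], exist_ok=True)
        for f in self.listfiles(path):
            shutil.copy(f, os.path.join(self.config["container"], os.path.basename(f)))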
txtai/cloud/factory.py
ADDED
@@ -0,0 +1,70 @@
"""
Factory module
"""

from ..util import Resolver

from .hub import HuggingFaceHub
from .storage import ObjectStorage, LIBCLOUD


class CloudFactory:
    """
    Methods to create Cloud instances.
    """

    @staticmethod
    def create(config):
        """
        Creates a Cloud instance.

        Args:
            config: cloud configuration

        Returns:
            Cloud
        """

        # Cloud instance
        cloud = None

        provider = config.get("provider", "")

        # Hugging Face Hub
        if provider.lower() == "huggingface-hub":
            cloud = HuggingFaceHub(config)

        # Cloud object storage
        elif ObjectStorage.isprovider(provider):
            cloud = ObjectStorage(config)

        # External provider
        elif provider:
            cloud = CloudFactory.resolve(provider, config)

        return cloud

    @staticmethod
    def resolve(backend, config):
        """
        Attempt to resolve a custom cloud backend.

        Args:
            backend: backend class
            config: configuration parameters

        Returns:
            Cloud
        """

        try:
            return Resolver()(backend)(config)

        except Exception as e:
            # Failure message
            message = f'Unable to resolve cloud backend: "{backend}".'

            # Append message if LIBCLOUD is not installed
            message += ' Cloud storage is not available - install "cloud" extra to enable' if not LIBCLOUD else ""

            raise ImportError(message) from e
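A sketch of the configuration dictionaries this factory accepts; container names and credentials below are placeholders, and the object storage example requires the optional "cloud" extra:

# Configuration sketches for CloudFactory.create; all values are placeholders.
from txtai.cloud.factory import CloudFactory

# Hugging Face Hub provider
hub = CloudFactory.create({"provider": "huggingface-hub", "container": "user/txtai-index"})

# Apache libcloud object storage - any provider name present in libcloud's DRIVERS (e.g. "s3")
storage = CloudFactory.create({"provider": "s3", "container": "txtai-indexes", "key": "ACCESS_KEY", "secret": "ACCESS_SECRET"})

# Any other non-empty provider is treated as a class path and handed to CloudFactory.resolve, e.g.
# CloudFactory.create({"provider": "mymodule.LocalMirror", "container": "/tmp/mirror"})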
txtai/cloud/hub.py
ADDED
@@ -0,0 +1,101 @@
"""
Hugging Face Hub module
"""

import os
import tempfile

import huggingface_hub

from huggingface_hub.utils import RepositoryNotFoundError

from .base import Cloud


class HuggingFaceHub(Cloud):
    """
    Hugging Face Hub cloud provider.
    """

    def metadata(self, path=None):
        try:
            # If this is an archive file path, get file metadata
            if self.isarchive(path):
                url = huggingface_hub.hf_hub_url(
                    repo_id=self.config["container"], filename=os.path.basename(path), revision=self.config.get("revision")
                )

                return huggingface_hub.get_hf_file_metadata(url=url, token=self.config.get("token"))

            # Otherwise return repository metadata
            return huggingface_hub.model_info(repo_id=self.config["container"], revision=self.config.get("revision"), token=self.config.get("token"))

        except RepositoryNotFoundError:
            return None

    def load(self, path=None):
        # Download archive file and return local path
        if self.isarchive(path):
            return huggingface_hub.hf_hub_download(
                repo_id=self.config["container"],
                filename=os.path.basename(path),
                revision=self.config.get("revision"),
                cache_dir=self.config.get("cache"),
                token=self.config.get("token"),
            )

        # Download repository and return cached path
        return huggingface_hub.snapshot_download(
            repo_id=self.config["container"], revision=self.config.get("revision"), cache_dir=self.config.get("cache"), token=self.config.get("token")
        )

    def save(self, path):
        # Get or create repository
        huggingface_hub.create_repo(
            repo_id=self.config["container"], token=self.config.get("token"), private=self.config.get("private", True), exist_ok=True
        )

        # Enable lfs-tracking of embeddings index files
        self.lfstrack()

        # Upload files
        for f in self.listfiles(path):
            huggingface_hub.upload_file(
                repo_id=self.config["container"],
                revision=self.config.get("revision"),
                token=self.config.get("token"),
                path_or_fileobj=f,
                path_in_repo=os.path.basename(f),
            )

    def lfstrack(self):
        """
        Adds lfs-tracking of embeddings index files. This method adds tracking for documents and embeddings to .gitattributes.
        """

        # Get and read .gitattributes file
        path = huggingface_hub.hf_hub_download(
            repo_id=self.config["container"], filename=os.path.basename(".gitattributes"), token=self.config.get("token")
        )

        with open(path, "r", encoding="utf-8") as f:
            content = f.read()

        # Check if index files are lfs-tracked. Update .gitattributes, if necessary.
        if "embeddings " not in content:
            # Add documents and embeddings to lfs tracking
            content += "documents filter=lfs diff=lfs merge=lfs -text\n"
            content += "embeddings filter=lfs diff=lfs merge=lfs -text\n"

            # pylint: disable=R1732
            with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
                tmp.write(content)
                attributes = tmp.name

            # Upload file
            huggingface_hub.upload_file(
                repo_id=self.config["container"], token=self.config.get("token"), path_or_fileobj=attributes, path_in_repo=os.path.basename(path)
            )

            # Remove temporary file
            os.remove(attributes)
txtai/cloud/storage.py
ADDED
@@ -0,0 +1,125 @@
"""
Object storage module
"""

import os

# Conditional import
try:
    from libcloud.storage.providers import get_driver, DRIVERS
    from libcloud.storage.types import ContainerDoesNotExistError, ObjectDoesNotExistError

    LIBCLOUD = True
except ImportError:
    LIBCLOUD, DRIVERS = False, None


from .base import Cloud


class ObjectStorage(Cloud):
    """
    Object storage cloud provider backed by Apache libcloud.
    """

    @staticmethod
    def isprovider(provider):
        """
        Checks if this provider is an object storage provider.

        Args:
            provider: provider name

        Returns:
            True if this is an object storage provider
        """

        return LIBCLOUD and provider and provider.lower() in [x.lower() for x in DRIVERS]

    def __init__(self, config):
        super().__init__(config)

        if not LIBCLOUD:
            raise ImportError('Cloud object storage is not available - install "cloud" extra to enable')

        # Get driver for provider
        driver = get_driver(config["provider"])

        # Get client connection
        self.client = driver(
            config.get("key", os.environ.get("ACCESS_KEY")),
            config.get("secret", os.environ.get("ACCESS_SECRET")),
            **{field: config.get(field) for field in ["host", "port", "region", "token"] if config.get(field)},
        )

    def metadata(self, path=None):
        try:
            # If this is an archive path, check if file exists
            if self.isarchive(path):
                return self.client.get_object(self.config["container"], self.objectname(path))

            # Otherwise check if container exists
            return self.client.get_container(self.config["container"])
        except (ContainerDoesNotExistError, ObjectDoesNotExistError):
            return None

    def load(self, path=None):
        # Download archive file
        if self.isarchive(path):
            obj = self.client.get_object(self.config["container"], self.objectname(path))

            # Create local directory, if necessary
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            obj.download(path, overwrite_existing=True)

        # Download files in container. Optionally filter with a provided prefix.
        else:
            container = self.client.get_container(self.config["container"])
            for obj in container.list_objects(prefix=self.config.get("prefix")):
                # Derive local path and directory
                localpath = os.path.join(path, obj.name)
                directory = os.path.dirname(localpath)

                # Create local directory, if necessary
                os.makedirs(directory, exist_ok=True)

                # Download file locally
                obj.download(localpath, overwrite_existing=True)

        return path

    def save(self, path):
        # Get or create container
        try:
            container = self.client.get_container(self.config["container"])
        except ContainerDoesNotExistError:
            container = self.client.create_container(self.config["container"])

        # Upload files
        for f in self.listfiles(path):
            with open(f, "rb") as iterator:
                self.client.upload_object_via_stream(iterator=iterator, container=container, object_name=self.objectname(f))

    def objectname(self, name):
        """
        Derives an object name. This method checks if a prefix configuration parameter is present and combines
        it with the input name parameter.

        Args:
            name: input name

        Returns:
            object name
        """

        # Get base name
        name = os.path.basename(name)

        # Get optional prefix/folder
        prefix = self.config.get("prefix")

        # Prepend prefix, if applicable
        return f"{prefix}/{name}" if prefix else name
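These cloud providers are the sync layer behind index persistence. As a hedged end-to-end sketch: in upstream txtai, which this package repackages, the embeddings save and load methods accept a cloud configuration that is routed through CloudFactory.create; the model, repository and file names below are placeholders.

# Hedged end-to-end sketch; assumes the upstream txtai Embeddings save/load cloud keyword.
from txtai.embeddings import Embeddings

cloud = {"provider": "huggingface-hub", "container": "user/txtai-demo-index"}

embeddings = Embeddings(path="sentence-transformers/all-MiniLM-L6-v2")
embeddings.index([(0, "sample text", None)])

# Archive packs the index (format inferred from the extension), then the cloud provider uploads it
embeddings.save("index.tar.gz", cloud=cloud)

# The cloud provider downloads the archive, then Archive unpacks it locally
embeddings.load("index.tar.gz", cloud=cloud)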