mseep-txtai 9.1.1__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. The information is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/vectors/recovery.py ADDED

@@ -0,0 +1,57 @@
"""
Recovery module
"""

import os
import shutil


class Recovery:
    """
    Vector embeddings recovery. This class handles streaming embeddings from a vector checkpoint file.
    """

    def __init__(self, checkpoint, vectorsid, load):
        """
        Creates a Recovery instance.

        Args:
            checkpoint: checkpoint directory
            vectorsid: vectors uid for current configuration
            load: load embeddings method
        """

        self.spool, self.path, self.load = None, None, load

        # Get unique file id
        path = f"{checkpoint}/{vectorsid}"
        if os.path.exists(path):
            # Generate recovery path
            self.path = f"{checkpoint}/recovery"

            # Copy current checkpoint to recovery
            shutil.copyfile(path, self.path)

            # Open file and return
            # pylint: disable=R1732
            self.spool = open(self.path, "rb")

    def __call__(self):
        """
        Reads and returns the next batch of embeddings.

        Returns:
            batch of embeddings
        """

        try:
            return self.load(self.spool) if self.spool else None
        except EOFError:
            # End of spool file, cleanup
            self.spool.close()
            os.remove(self.path)

            # Clear parameters
            self.spool, self.path = None, None

            return None
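For context, a minimal usage sketch of the Recovery class above. The checkpoint directory and vectors uid are hypothetical, and pickle.load stands in for the load method a real Vectors instance would pass; none of this is part of the package.

import pickle
import tempfile

from txtai.vectors.recovery import Recovery

# Hypothetical checkpoint directory and vectors uid, for illustration only
checkpoint, vectorsid = tempfile.mkdtemp(), "vectors-uid"

# Write a fake checkpoint file containing two pickled batches
with open(f"{checkpoint}/{vectorsid}", "wb") as f:
    pickle.dump([[0.1, 0.2]], f)
    pickle.dump([[0.3, 0.4]], f)

# pickle.load stands in for the load method passed by the caller
recovery = Recovery(checkpoint, vectorsid, pickle.load)

# Stream batches until EOFError triggers cleanup and the call returns None
while (batch := recovery()) is not None:
    print(batch)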
txtai/vectors/sparse/base.py ADDED

@@ -0,0 +1,90 @@
"""
SparseVectors module
"""

# Conditional import
try:
    from scipy.sparse import csr_matrix, vstack
    from sklearn.preprocessing import normalize
    from sklearn.utils.extmath import safe_sparse_dot

    SPARSE = True
except ImportError:
    SPARSE = False

from ...util import SparseArray
from ..base import Vectors


# pylint: disable=W0223
class SparseVectors(Vectors):
    """
    Base class for sparse vector models. Vector models transform input content into sparse arrays.
    """

    def __init__(self, config, scoring, models):
        # Check before parent constructor since it calls loadmodel
        if not SPARSE:
            raise ImportError('SparseVectors is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

        # Get normalization setting
        self.isnormalize = self.config.get("normalize", self.defaultnormalize()) if self.config else None

    def encode(self, data, category=None):
        # Encode data to embeddings
        embeddings = super().encode(data, category)

        # Get sparse torch vector attributes
        embeddings = embeddings.cpu().coalesce()
        indices = embeddings.indices().numpy()
        values = embeddings.values().numpy()

        # Return as SciPy CSR Matrix
        return csr_matrix((values, indices), shape=embeddings.size())

    def vectors(self, documents, batchsize=500, checkpoint=None, buffer=None, dtype=None):
        # Run indexing
        ids, dimensions, batches, stream = self.index(documents, batchsize, checkpoint)

        # Rebuild sparse array
        embeddings = None
        with open(stream, "rb") as queue:
            for _ in range(batches):
                # Read in array batch
                data = self.loadembeddings(queue)
                embeddings = vstack((embeddings, data)) if embeddings is not None else data

        # Return sparse array
        return (ids, dimensions, embeddings)

    def dot(self, queries, data):
        return safe_sparse_dot(queries, data.T, dense_output=True).tolist()

    def loadembeddings(self, f):
        return SparseArray().load(f)

    def saveembeddings(self, f, embeddings):
        SparseArray().save(f, embeddings)

    def truncate(self, embeddings):
        raise ValueError("Truncate is not supported for sparse vectors")

    def normalize(self, embeddings):
        # Optionally normalize embeddings using method that supports sparse vectors
        return normalize(embeddings, copy=False) if self.isnormalize else embeddings

    def quantize(self, embeddings):
        raise ValueError("Quantize is not supported for sparse vectors")

    def defaultnormalize(self):
        """
        Returns the default normalization setting.

        Returns:
            default normalization setting
        """

        # Sparse vector embeddings typically perform better as unnormalized
        return False
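A standalone sketch of the tensor-to-CSR conversion that encode() performs, using toy values in place of real model output:

import torch
from scipy.sparse import csr_matrix

# Toy sparse COO tensor: 2 rows x 5 columns with 3 nonzero values
tensor = torch.sparse_coo_tensor([[0, 0, 1], [1, 4, 2]], [0.5, 1.0, 0.25], (2, 5))

# Mirror of encode(): coalesce, extract indices/values, wrap as a SciPy CSR matrix
tensor = tensor.cpu().coalesce()
matrix = csr_matrix((tensor.values().numpy(), tensor.indices().numpy()), shape=tensor.size())

print(matrix.shape, matrix.nnz)  # (2, 5) 3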
txtai/vectors/sparse/factory.py ADDED

@@ -0,0 +1,55 @@
"""
Factory module
"""

from ...util import Resolver

from .sbert import SparseSTVectors


class SparseVectorsFactory:
    """
    Methods to create sparse vector models.
    """

    @staticmethod
    def create(config, models=None):
        """
        Create a Vectors model instance.

        Args:
            config: vector configuration
            models: models cache

        Returns:
            Vectors
        """

        # Get vector method
        method = config.get("method", "sentence-transformers")

        # Sentence Transformers vectors
        if method == "sentence-transformers":
            return SparseSTVectors(config, None, models) if config and config.get("path") else None

        # Resolve custom method
        return SparseVectorsFactory.resolve(method, config, models) if method else None

    @staticmethod
    def resolve(backend, config, models):
        """
        Attempt to resolve a custom backend.

        Args:
            backend: backend class
            config: vector configuration
            models: models cache

        Returns:
            Vectors
        """

        try:
            return Resolver()(backend)(config, None, models)
        except Exception as e:
            raise ImportError(f"Unable to resolve sparse vectors backend: '{backend}'") from e
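A sketch of config-driven creation through this factory. The model path is illustrative; any sparse encoder model compatible with sentence-transformers would work.

from txtai.vectors.sparse.factory import SparseVectorsFactory

# "path" is an illustrative sparse encoder model id; loading it downloads the model
config = {"method": "sentence-transformers", "path": "naver/splade-v3"}

# Returns a SparseSTVectors instance; a dotted class path in "method"
# would instead be resolved as a custom backend via Resolver
vectors = SparseVectorsFactory.create(config)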
txtai/vectors/sparse/sbert.py ADDED

@@ -0,0 +1,34 @@
"""
Sparse Sentence Transformers module
"""

# Conditional import
try:
    from sentence_transformers import SparseEncoder

    SENTENCE_TRANSFORMERS = True
except ImportError:
    SENTENCE_TRANSFORMERS = False

from ..dense.sbert import STVectors
from .base import SparseVectors


class SparseSTVectors(SparseVectors, STVectors):
    """
    Builds sparse vectors using sentence-transformers (aka SBERT).
    """

    def __init__(self, config, scoring, models):
        # Check before parent constructor since it calls loadmodel
        if not SENTENCE_TRANSFORMERS:
            raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadencoder(self, path, device, **kwargs):
        return SparseEncoder(path, device=device, **kwargs)

    def defaultnormalize(self):
        # Enable normalization by default if similarity function is cosine
        return self.model and self.model.similarity_fn_name == "cosine"
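For reference, a sketch of the SparseEncoder class that loadencoder wraps, used directly. The model id is illustrative and the printed value depends on the model's configuration.

from sentence_transformers import SparseEncoder

# Load a sparse encoder directly - this is what loadencoder wraps
encoder = SparseEncoder("naver/splade-v3", device="cpu")

# defaultnormalize keys off this attribute: "cosine" enables normalization,
# while SPLADE-style models typically report "dot" and stay unnormalized
print(encoder.similarity_fn_name)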
txtai/version.py ADDED

txtai/workflow/base.py ADDED
@@ -0,0 +1,184 @@
"""
Workflow module
"""

import logging
import time
import traceback

from datetime import datetime

# Conditional import
try:
    from croniter import croniter

    CRONITER = True
except ImportError:
    CRONITER = False

from .execute import Execute

# Logging configuration
logger = logging.getLogger(__name__)


class Workflow:
    """
    Base class for all workflows.
    """

    def __init__(self, tasks, batch=100, workers=None, name=None, stream=None):
        """
        Creates a new workflow. Workflows are lists of tasks to execute.

        Args:
            tasks: list of workflow tasks
            batch: how many items to process at a time, defaults to 100
            workers: number of concurrent workers
            name: workflow name
            stream: workflow stream processor
        """

        self.tasks = tasks
        self.batch = batch
        self.workers = workers
        self.name = name
        self.stream = stream

        # Set default number of executor workers to max number of actions in a task
        self.workers = max(len(task.action) for task in self.tasks) if not self.workers else self.workers

    def __call__(self, elements):
        """
        Executes a workflow for input elements. This method returns a generator that yields transformed
        data elements.

        Args:
            elements: iterable data elements

        Returns:
            generator that yields transformed data elements
        """

        # Create execute instance for this run
        with Execute(self.workers) as executor:
            # Run task initializers
            self.initialize()

            # Process elements with stream processor, if available
            elements = self.stream(elements) if self.stream else elements

            # Process elements in batches
            for batch in self.chunk(elements):
                yield from self.process(batch, executor)

            # Run task finalizers
            self.finalize()

    def schedule(self, cron, elements, iterations=None):
        """
        Schedules a workflow using a cron expression and elements.

        Args:
            cron: cron expression
            elements: iterable data elements passed to workflow each call
            iterations: number of times to run workflow, defaults to running indefinitely
        """

        # Check that croniter is installed
        if not CRONITER:
            raise ImportError('Workflow scheduling is not available - install "workflow" extra to enable')

        logger.info("'%s' scheduler started with schedule %s", self.name, cron)

        maxiterations = iterations
        while iterations is None or iterations > 0:
            # Schedule using localtime
            schedule = croniter(cron, datetime.now().astimezone()).get_next(datetime)
            logger.info("'%s' next run scheduled for %s", self.name, schedule.isoformat())
            time.sleep(schedule.timestamp() - time.time())

            # Run workflow
            # pylint: disable=W0703
            try:
                for _ in self(elements):
                    pass
            except Exception:
                logger.error(traceback.format_exc())

            # Decrement iterations remaining, if necessary
            if iterations is not None:
                iterations -= 1

        logger.info("'%s' max iterations (%d) reached", self.name, maxiterations)

    def initialize(self):
        """
        Runs task initializer methods (if any) before processing data.
        """

        # Run task initializers
        for task in self.tasks:
            if task.initialize:
                task.initialize()

    def chunk(self, elements):
        """
        Splits elements into batches. This method efficiently processes both fixed size inputs and
        dynamically generated inputs.

        Args:
            elements: iterable data elements

        Returns:
            evenly sized batches with the last batch having the remaining elements
        """

        # Build batches by slicing elements, more efficient for fixed sized inputs
        if hasattr(elements, "__len__") and hasattr(elements, "__getitem__"):
            for x in range(0, len(elements), self.batch):
                yield elements[x : x + self.batch]

        # Build batches by iterating over elements when inputs are dynamically generated (i.e. generators)
        else:
            batch = []
            for x in elements:
                batch.append(x)

                if len(batch) == self.batch:
                    yield batch
                    batch = []

            # Final batch
            if batch:
                yield batch

    def process(self, elements, executor):
        """
        Processes a batch of data elements.

        Args:
            elements: iterable data elements
            executor: execute instance, enables concurrent task actions

        Returns:
            transformed data elements
        """

        # Run elements through each task
        for x, task in enumerate(self.tasks):
            logger.debug("Running Task #%d", x)
            elements = task(elements, executor)

        # Yield results processed by all tasks
        yield from elements

    def finalize(self):
        """
        Runs task finalizer methods (if any) after all data processed.
        """

        # Run task finalizers
        for task in self.tasks:
            if task.finalize:
                task.finalize()
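A minimal usage sketch of the Workflow class with txtai's Task, following the pattern from the txtai documentation:

from txtai.workflow import Task, Workflow

# Each Task action receives a batch of elements and returns transformed elements
workflow = Workflow([
    Task(lambda elements: [x.lower() for x in elements]),
    Task(lambda elements: [x for x in elements if len(x) > 2]),
])

# Workflows are generators: results stream out as batches finish
print(list(workflow(["Hello", "Hi", "WORKFLOW"])))  # ['hello', 'workflow']

# schedule() reruns the same workflow on a cron expression (requires croniter)
# workflow.schedule("*/5 * * * *", ["Hello"], iterations=1)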
txtai/workflow/execute.py ADDED

@@ -0,0 +1,99 @@
"""
Execute module
"""

from multiprocessing.pool import Pool, ThreadPool

import torch.multiprocessing


class Execute:
    """
    Supports sequential, multithreading and multiprocessing based execution of tasks.
    """

    def __init__(self, workers=None):
        """
        Creates a new execute instance. Functions can be executed sequentially, in a thread pool
        or in a process pool. Once created, the thread and/or process pool will stay open until the
        close method is called.

        Args:
            workers: number of workers for thread/process pools
        """

        # Number of workers to use in thread/process pools
        self.workers = workers

        self.thread = None
        self.process = None

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, etype, value, traceback):
        self.close()

    def run(self, method, function, args):
        """
        Runs multiple calls of function for each tuple in args. The method parameter controls if the calls are
        sequential (method=None), multithreaded (method="thread") or with multiprocessing (method="process").

        Args:
            method: run method - "thread" for multithreading, "process" for multiprocessing, otherwise runs sequentially
            function: function to run
            args: list of tuples with arguments to each call
        """

        # Concurrent processing
        if method and len(args) > 1:
            pool = self.pool(method)
            if pool:
                return pool.starmap(function, args, 1)

        # Sequential processing
        return [function(*arg) for arg in args]

    def pool(self, method):
        """
        Gets a handle to a concurrent processing pool. This method will create the pool if it doesn't already exist.

        Args:
            method: pool type - "thread" or "process"

        Returns:
            concurrent processing pool or None if no pool of that type available
        """

        if method == "thread":
            if not self.thread:
                self.thread = ThreadPool(self.workers)

            return self.thread

        if method == "process":
            if not self.process:
                # Importing torch.multiprocessing will register torch shared memory serialization for cuda
                self.process = Pool(self.workers, context=torch.multiprocessing.get_context("spawn"))

            return self.process

        return None

    def close(self):
        """
        Closes concurrent processing pools.
        """

        if hasattr(self, "thread") and self.thread:
            self.thread.close()
            self.thread.join()
            self.thread = None

        if hasattr(self, "process") and self.process:
            self.process.close()
            self.process.join()
            self.process = None
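A usage sketch for Execute. Note the "process" method uses a spawn context, so functions fanned out that way must be picklable top-level definitions.

from txtai.workflow.execute import Execute

def multiply(x, y):
    # Top-level function so it would also work with the "process" method
    return x * y

args = [(2, 3), (4, 5), (6, 7)]

with Execute(workers=2) as executor:
    # method=None runs the calls sequentially in the calling thread
    print(executor.run(None, multiply, args))      # [6, 20, 42]

    # method="thread" fans the same calls out over a lazily created ThreadPool
    print(executor.run("thread", multiply, args))  # [6, 20, 42]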
txtai/workflow/factory.py ADDED

@@ -0,0 +1,42 @@
"""
Workflow factory module
"""

from .base import Workflow
from .task import TaskFactory


class WorkflowFactory:
    """
    Workflow factory. Creates new Workflow instances.
    """

    @staticmethod
    def create(config, name):
        """
        Creates a new Workflow instance.

        Args:
            config: Workflow configuration
            name: Workflow name

        Returns:
            Workflow
        """

        # Resolve workflow tasks
        tasks = []
        for tconfig in config["tasks"]:
            task = tconfig.pop("task") if "task" in tconfig else ""
            tasks.append(TaskFactory.create(tconfig, task))

        config["tasks"] = tasks

        if "stream" in config:
            sconfig = config["stream"]
            task = sconfig.pop("task") if "task" in sconfig else "stream"

            config["stream"] = TaskFactory.create(sconfig, task)

        # Create workflow
        return Workflow(**config, name=name)
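A sketch of creating a workflow through the factory, assuming a hypothetical configuration shaped like a deserialized YAML workflow definition; "console" is one of the task types registered with TaskFactory.

from txtai.workflow.factory import WorkflowFactory

# Hypothetical configuration, shaped like a deserialized YAML workflow definition
config = {"tasks": [{"task": "console"}], "batch": 10}

# Factory pops each "task" name, builds Task instances via TaskFactory,
# then passes the remaining keys to the Workflow constructor
workflow = WorkflowFactory.create(config, "debug")

# Run the workflow; the console task prints each batch of elements
for _ in workflow(["hello", "world"]):
    pass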
txtai/workflow/task/__init__.py ADDED

@@ -0,0 +1,18 @@
"""
Task imports
"""

from .base import Task
from .console import ConsoleTask
from .export import ExportTask
from .factory import TaskFactory
from .file import FileTask
from .image import ImageTask
from .retrieve import RetrieveTask
from .service import ServiceTask
from .storage import StorageTask
from .stream import StreamTask
from .template import *
from .template import RagTask as ExtractorTask
from .url import UrlTask
from .workflow import WorkflowTask