mseep_txtai-9.1.1-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
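Per top_level.txt, the wheel installs a single top-level package, `txtai`. As a quick orientation before the source hunks below, here is a minimal sketch of the primary API the package provides; the model id and sample documents are illustrative assumptions, not part of the diff:

```python
from txtai import Embeddings

# Hypothetical model id; any sentence-transformers model can be used
embeddings = Embeddings(path="sentence-transformers/all-MiniLM-L6-v2")

# Index a few sample documents as (id, data, tags) tuples and run a semantic search
embeddings.index([
    (0, "US tops 5 million confirmed virus cases", None),
    (1, "Maine man wins $1M from $25 lottery ticket", None),
])

print(embeddings.search("public health story", 1))
```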
txtai/pipeline/image/caption.py

```diff
@@ -0,0 +1,55 @@
+"""
+Caption module
+"""
+
+# Conditional import
+try:
+    from PIL import Image
+
+    PIL = True
+except ImportError:
+    PIL = False
+
+from ..hfpipeline import HFPipeline
+
+
+class Caption(HFPipeline):
+    """
+    Constructs captions for images.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+        if not PIL:
+            raise ImportError('Captions pipeline is not available - install "pipeline" extra to enable')
+
+        # Call parent constructor
+        super().__init__("image-to-text", path, quantize, gpu, model, **kwargs)
+
+    def __call__(self, images):
+        """
+        Builds captions for images.
+
+        This method supports a single image or a list of images. If the input is an image, the return
+        type is a string. If the input is a list, a list of strings is returned.
+
+        Args:
+            images: image|list
+
+        Returns:
+            list of captions
+        """
+
+        # Convert single element to list
+        values = [images] if not isinstance(images, list) else images
+
+        # Open images if file strings
+        values = [Image.open(image) if isinstance(image, str) else image for image in values]
+
+        # Get and clean captions
+        captions = []
+        for result in self.pipeline(values):
+            text = " ".join([r["generated_text"] for r in result]).strip()
+            captions.append(text)
+
+        # Return single element if single element passed in
+        return captions[0] if not isinstance(images, list) else captions
```
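A short usage sketch for the Caption pipeline above, assuming the `pipeline` extra is installed; the image file names are hypothetical and the model defaults to the Hugging Face image-to-text task default:

```python
from txtai.pipeline import Caption

# Load the default image-to-text model
caption = Caption()

# A single image returns a string, a list returns a list of strings
print(caption("photo.jpg"))
print(caption(["photo.jpg", "chart.png"]))
```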
txtai/pipeline/image/imagehash.py

```diff
@@ -0,0 +1,90 @@
+"""
+ImageHash module
+"""
+
+import numpy as np
+
+# Conditional import
+try:
+    from PIL import Image
+    import imagehash
+
+    PIL = True
+except ImportError:
+    PIL = False
+
+from ..base import Pipeline
+
+
+class ImageHash(Pipeline):
+    """
+    Generates perceptual image hashes. These hashes can be used to detect near-duplicate images. This method is not
+    backed by machine learning models and not intended to find conceptually similar images.
+    """
+
+    def __init__(self, algorithm="average", size=8, strings=True):
+        """
+        Creates an ImageHash pipeline.
+
+        Args:
+            algorithm: image hashing algorithm (average, perceptual, difference, wavelet, color)
+            size: hash size
+            strings: outputs hex strings if True (default), otherwise the pipeline returns numpy arrays
+        """
+
+        if not PIL:
+            raise ImportError('ImageHash pipeline is not available - install "pipeline" extra to enable')
+
+        self.algorithm = algorithm
+        self.size = size
+        self.strings = strings
+
+    def __call__(self, images):
+        """
+        Generates perceptual image hashes.
+
+        Args:
+            images: image|list
+
+        Returns:
+            list of hashes
+        """
+
+        # Convert single element to list
+        values = [images] if not isinstance(images, list) else images
+
+        # Open images if file strings
+        values = [Image.open(image) if isinstance(image, str) else image for image in values]
+
+        # Convert images to hashes
+        hashes = [self.ihash(image) for image in values]
+
+        # Return single element if single element passed in
+        return hashes[0] if not isinstance(images, list) else hashes
+
+    def ihash(self, image):
+        """
+        Gets an image hash for image.
+
+        Args:
+            image: PIL image
+
+        Returns:
+            hash as hex string
+        """
+
+        # Apply hash algorithm
+        if self.algorithm == "perceptual":
+            data = imagehash.phash(image, self.size)
+        elif self.algorithm == "difference":
+            data = imagehash.dhash(image, self.size)
+        elif self.algorithm == "wavelet":
+            data = imagehash.whash(image, self.size)
+        elif self.algorithm == "color":
+            data = imagehash.colorhash(image, self.size)
+        else:
+            # Default to average hash
+            data = imagehash.average_hash(image, self.size)
+
+        # Convert to output data type
+        return str(data) if self.strings else data.hash.astype(np.float32).reshape(-1)
```
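A hedged sketch of how the ImageHash pipeline above can flag near-duplicates; the file names are hypothetical, and `strings=False` is used so hashes come back as numpy arrays that can be compared with a Hamming distance:

```python
import numpy as np

from txtai.pipeline import ImageHash

# Perceptual hashes as numpy arrays instead of hex strings
ihash = ImageHash(algorithm="perceptual", strings=False)

h1, h2 = ihash(["image1.jpg", "image2.jpg"])

# Smaller Hamming distance = more visually similar images
print(np.count_nonzero(h1 != h2))
```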
txtai/pipeline/image/objects.py

```diff
@@ -0,0 +1,80 @@
+"""
+Objects module
+"""
+
+# Conditional import
+try:
+    from PIL import Image
+
+    PIL = True
+except ImportError:
+    PIL = False
+
+from ..hfpipeline import HFPipeline
+
+
+class Objects(HFPipeline):
+    """
+    Applies object detection models to images. Supports both object detection models and image classification models.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=True, model=None, classification=False, threshold=0.9, **kwargs):
+        if not PIL:
+            raise ImportError('Objects pipeline is not available - install "pipeline" extra to enable')
+
+        super().__init__("image-classification" if classification else "object-detection", path, quantize, gpu, model, **kwargs)
+
+        self.classification = classification
+        self.threshold = threshold
+
+    def __call__(self, images, flatten=False, workers=0):
+        """
+        Applies object detection/image classification models to images. Returns a list of (label, score).
+
+        This method supports a single image or a list of images. If the input is an image, the return
+        type is a 1D list of (label, score). If the input is a list, a 2D list of (label, score) is
+        returned with a row per image.
+
+        Args:
+            images: image|list
+            flatten: flatten output to a list of objects
+            workers: number of concurrent workers to use for processing data, defaults to 0
+
+        Returns:
+            list of (label, score)
+        """
+
+        # Convert single element to list
+        values = [images] if not isinstance(images, list) else images
+
+        # Open images if file strings
+        values = [Image.open(image) if isinstance(image, str) else image for image in values]
+
+        # Run pipeline
+        results = (
+            self.pipeline(values, num_workers=workers)
+            if self.classification
+            else self.pipeline(values, threshold=self.threshold, num_workers=workers)
+        )
+
+        # Build list of (label, score)
+        outputs = []
+        for result in results:
+            # Convert to (label, score) tuples
+            result = [(x["label"], x["score"]) for x in result if x["score"] > self.threshold]
+
+            # Sort by score descending
+            result = sorted(result, key=lambda x: x[1], reverse=True)
+
+            # Deduplicate labels
+            unique = set()
+            elements = []
+            for label, score in result:
+                if label not in unique:
+                    elements.append(label if flatten else (label, score))
+                    unique.add(label)
+
+            outputs.append(elements)
+
+        # Return single element if single element passed in
+        return outputs[0] if not isinstance(images, list) else outputs
```
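A usage sketch for the Objects pipeline above; the image names are hypothetical and the default model is whatever Hugging Face resolves for the object-detection task:

```python
from txtai.pipeline import Objects

# Object detection by default; classification=True switches to image classification models
objects = Objects(threshold=0.9)

# Single image: list of (label, score), deduplicated and sorted by score
print(objects("street.jpg"))

# List input: one row per image; flatten=True returns labels only
print(objects(["street.jpg", "kitchen.jpg"], flatten=True))
```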
txtai/pipeline/llm/factory.py

```diff
@@ -0,0 +1,86 @@
+"""
+Factory module
+"""
+
+from ...util import Resolver
+
+from .huggingface import HFGeneration
+from .litellm import LiteLLM
+from .llama import LlamaCpp
+
+
+class GenerationFactory:
+    """
+    Methods to create generative models.
+    """
+
+    @staticmethod
+    def create(path, method, **kwargs):
+        """
+        Creates a new Generation instance.
+
+        Args:
+            path: model path
+            method: llm framework
+            kwargs: model keyword arguments
+        """
+
+        # Derive method
+        method = GenerationFactory.method(path, method)
+
+        # LiteLLM generation
+        if method == "litellm":
+            return LiteLLM(path, **kwargs)
+
+        # llama.cpp generation
+        if method == "llama.cpp":
+            return LlamaCpp(path, **kwargs)
+
+        # Hugging Face Transformers generation
+        if method == "transformers":
+            return HFGeneration(path, **kwargs)
+
+        # Resolve custom method
+        return GenerationFactory.resolve(path, method, **kwargs)
+
+    @staticmethod
+    def method(path, method):
+        """
+        Gets or derives the LLM framework.
+
+        Args:
+            path: model path
+            method: llm framework
+
+        Returns:
+            llm framework
+        """
+
+        if not method:
+            if LiteLLM.ismodel(path):
+                method = "litellm"
+            elif LlamaCpp.ismodel(path):
+                method = "llama.cpp"
+            else:
+                method = "transformers"
+
+        return method
+
+    @staticmethod
+    def resolve(path, method, **kwargs):
+        """
+        Attempts to resolve a custom LLM framework.
+
+        Args:
+            path: model path
+            method: llm framework
+            kwargs: model keyword arguments
+
+        Returns:
+            Generation instance
+        """
+
+        try:
+            return Resolver()(method)(path, **kwargs)
+        except Exception as e:
+            raise ImportError(f"Unable to resolve generation framework: '{method}'") from e
```
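A sketch of how the framework derivation in `GenerationFactory.method` typically plays out. The example paths are assumptions: which ids LiteLLM claims depends on the litellm library, and the llama.cpp routing assumes GGUF-style paths:

```python
from txtai.pipeline.llm.factory import GenerationFactory

# No explicit method: derive it from the model path
print(GenerationFactory.method("gpt-4o", None))             # "litellm" if litellm recognizes the id
print(GenerationFactory.method("model.Q4_K_M.gguf", None))  # "llama.cpp" for GGUF-style paths
print(GenerationFactory.method("microsoft/phi-2", None))    # "transformers" fallback

# An explicit method is returned unchanged; unknown values are later resolved as custom frameworks
print(GenerationFactory.method("custom.module.CustomLLM", "custom.module.CustomLLM"))
```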
txtai/pipeline/llm/generation.py

```diff
@@ -0,0 +1,173 @@
+"""
+Generation module
+"""
+
+import re
+
+from ...util import TemplateFormatter
+
+
+class Generation:
+    """
+    Base class for generative models. This class has common logic for building prompts and cleaning model results.
+    """
+
+    def __init__(self, path=None, template=None, **kwargs):
+        """
+        Creates a new Generation instance.
+
+        Args:
+            path: model path
+            template: prompt template
+            kwargs: additional keyword arguments
+        """
+
+        self.path = path
+        self.template = template
+        self.kwargs = kwargs
+
+    def __call__(self, text, maxlength, stream, stop, defaultrole, stripthink, **kwargs):
+        """
+        Generates text. Supports the following input formats:
+
+          - String or list of strings (instruction-tuned models must follow chat templates)
+          - List of dictionaries with `role` and `content` key-values or lists of lists
+
+        Args:
+            text: text|list
+            maxlength: maximum sequence length
+            stream: stream response if True, defaults to False
+            stop: list of stop strings
+            defaultrole: default role to apply to text inputs (prompt for raw prompts (default) or user for user chat messages)
+            stripthink: strip thinking tags, defaults to False
+            kwargs: additional generation keyword arguments
+
+        Returns:
+            generated text
+        """
+
+        # Format inputs
+        texts = [text] if isinstance(text, str) or isinstance(text[0], dict) else text
+
+        # Apply template, if necessary
+        if self.template:
+            formatter = TemplateFormatter()
+            texts = [formatter.format(self.template, text=x) if isinstance(x, str) else x for x in texts]
+
+        # Apply default role, if necessary
+        if defaultrole == "user":
+            texts = [[{"role": "user", "content": x}] if isinstance(x, str) else x for x in texts]
+
+        # Run pipeline
+        results = self.execute(texts, maxlength, stream, stop, **kwargs)
+
+        # Streaming generation
+        if stream:
+            return results
+
+        # Clean generated text
+        results = [self.clean(texts[x], result, stripthink) for x, result in enumerate(results)]
+
+        # Extract results based on inputs
+        return results[0] if isinstance(text, str) or isinstance(text[0], dict) else results
+
+    def isvision(self):
+        """
+        Returns True if this LLM supports vision operations.
+
+        Returns:
+            True if this is a vision model
+        """
+
+        return False
+
+    def execute(self, texts, maxlength, stream, stop, **kwargs):
+        """
+        Runs a list of prompts through a generative model.
+
+        Args:
+            texts: list of prompts to run
+            maxlength: maximum sequence length
+            stream: stream response if True, defaults to False
+            stop: list of stop strings
+            kwargs: additional generation keyword arguments
+
+        Returns:
+            generated text
+        """
+
+        # Streaming generation
+        if stream:
+            return self.stream(texts, maxlength, stream, stop, **kwargs)
+
+        # Full response as content elements
+        return list(self.stream(texts, maxlength, stream, stop, **kwargs))
+
+    def clean(self, prompt, result, stripthink):
+        """
+        Applies a series of rules to clean generated text.
+
+        Args:
+            prompt: original input prompt
+            result: result text
+            stripthink: removes thinking tags if true
+
+        Returns:
+            clean text
+        """
+
+        # Replace input prompt
+        text = result.replace(prompt, "") if isinstance(prompt, str) else result
+
+        # Replace thinking tags, if necessary
+        text = re.sub(r"(?s)<think>.+?</think>", "", text).strip() if stripthink else text
+
+        # Apply text cleaning rules
+        return text.replace("$=", "<=").strip()
+
+    def response(self, result):
+        """
+        Parses response content from the result. This supports both standard and streaming
+        generation.
+
+        For standard generation, the full response is returned. For streaming generation,
+        this method will stream chunks of content.
+
+        Args:
+            result: LLM response
+
+        Returns:
+            response
+        """
+
+        streamed = False
+        for chunk in result:
+            # Expects one of the following parameter paths
+            # - text
+            # - message.content
+            # - delta.content
+            data = chunk["choices"][0]
+            text = data.get("text", data.get("message", data.get("delta")))
+            text = text if isinstance(text, str) else text.get("content")
+
+            # Yield result if there is text AND it's not leading stream whitespace
+            if text is not None and (streamed or text.strip()):
+                yield (text.lstrip() if not streamed else text)
+                streamed = True
+
+    def stream(self, texts, maxlength, stream, stop, **kwargs):
+        """
+        Streams LLM responses.
+
+        Args:
+            texts: list of prompts to run
+            maxlength: maximum sequence length
+            stream: stream response if True, defaults to False
+            stop: list of stop strings
+            kwargs: additional generation keyword arguments
+
+        Returns:
+            responses
+        """
+
+        raise NotImplementedError
```
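To illustrate the contract defined by the Generation base class above, here is a hypothetical minimal backend. `EchoGeneration` is not part of txtai; it simply uppercases each prompt, but it shows that `stream` is the one method a subclass must implement, yielding one result per input prompt, while `__call__` handles templating, default roles and cleanup:

```python
from txtai.pipeline.llm.generation import Generation


class EchoGeneration(Generation):
    """
    Hypothetical backend for illustration only: "generates" text by uppercasing the prompt.
    """

    def stream(self, texts, maxlength, stream, stop, **kwargs):
        for text in texts:
            # Inputs are raw prompt strings or chat message lists
            prompt = text if isinstance(text, str) else " ".join(m["content"] for m in text)

            # Yield one generated string per prompt, truncated to maxlength characters
            yield prompt.upper()[:maxlength]


# Positional arguments mirror Generation.__call__(text, maxlength, stream, stop, defaultrole, stripthink)
llm = EchoGeneration()
print(llm("hello world", 100, False, None, "prompt", False))  # -> "HELLO WORLD"
```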