mseep_txtai-9.1.1-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/text/entity.py
@@ -0,0 +1,140 @@
+"""
+Entity module
+"""
+
+# Conditional import
+try:
+    from gliner import GLiNER
+
+    GLINER = True
+except ImportError:
+    GLINER = False
+
+from huggingface_hub.errors import HFValidationError
+from transformers.utils import cached_file
+
+from ...models import Models
+from ..hfpipeline import HFPipeline
+
+
+class Entity(HFPipeline):
+    """
+    Applies a token classifier to text and extracts entity/label combinations.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+        # Create a new entity pipeline
+        self.gliner = self.isgliner(path)
+        if self.gliner:
+            if not GLINER:
+                raise ImportError('GLiNER is not available - install "pipeline" extra to enable')
+
+            # GLiNER entity pipeline
+            self.pipeline = GLiNER.from_pretrained(path)
+            self.pipeline = self.pipeline.to(Models.device(Models.deviceid(gpu)))
+        else:
+            # Standard entity pipeline
+            super().__init__("token-classification", path, quantize, gpu, model, **kwargs)
+
+    def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=False, workers=0):
+        """
+        Applies a token classifier to text and extracts entity/label combinations.
+
+        Args:
+            text: text|list
+            labels: list of entity type labels to accept, defaults to None which accepts all
+            aggregate: method to combine multi token entities - options are "simple" (default), "first", "average" or "max"
+            flatten: flatten output to a list of labels if present. Accepts a boolean or float value to only keep scores greater than that number.
+            join: joins flattened output into a string if True, ignored if flatten not set
+            workers: number of concurrent workers to use for processing data, defaults to 0
+
+        Returns:
+            list of (entity, entity type, score) or list of entities depending on flatten parameter
+        """
+
+        # Run token classification pipeline
+        results = self.execute(text, labels, aggregate, workers)
+
+        # Convert results to a list if necessary
+        if isinstance(text, str):
+            results = [results]
+
+        # Score threshold when flatten is set
+        threshold = 0.0 if isinstance(flatten, bool) else flatten
+
+        # Extract entities if flatten set, otherwise extract (entity, entity type, score) tuples
+        outputs = []
+        for result in results:
+            if flatten:
+                output = [r["word"] for r in result if self.accept(r["entity_group"], labels) and r["score"] >= threshold]
+                outputs.append(" ".join(output) if join else output)
+            else:
+                outputs.append([(r["word"], r["entity_group"], float(r["score"])) for r in result if self.accept(r["entity_group"], labels)])
+
+        return outputs[0] if isinstance(text, str) else outputs
+
+    def isgliner(self, path):
+        """
+        Tests if path is a GLiNER model.
+
+        Args:
+            path: model path
+
+        Returns:
+            True if this is a GLiNER model, False otherwise
+        """
+
+        try:
+            # Test if this model has a gliner_config.json file
+            return cached_file(path_or_repo_id=path, filename="gliner_config.json") is not None
+
+        # Ignore this error - invalid repo or directory
+        except (HFValidationError, OSError):
+            pass
+
+        return False
+
+    def execute(self, text, labels, aggregate, workers):
+        """
+        Runs the entity extraction pipeline.
+
+        Args:
+            text: text|list
+            labels: list of entity type labels to accept, defaults to None which accepts all
+            aggregate: method to combine multi token entities - options are "simple" (default), "first", "average" or "max"
+            workers: number of concurrent workers to use for processing data, defaults to 0
+
+        Returns:
+            list of entities and labels
+        """
+
+        if self.gliner:
+            # Extract entities with GLiNER. Use default CoNLL-2003 labels when not otherwise provided.
+            results = self.pipeline.batch_predict_entities(
+                text if isinstance(text, list) else [text], labels if labels else ["person", "organization", "location"]
+            )
+
+            # Map results to same format as Transformers token classifier
+            entities = []
+            for result in results:
+                entities.append([{"word": x["text"], "entity_group": x["label"], "score": x["score"]} for x in result])
+
+            # Return extracted entities
+            return entities if isinstance(text, list) else entities[0]
+
+        # Standard Transformers token classification pipeline
+        return self.pipeline(text, aggregation_strategy=aggregate, num_workers=workers)
+
+    def accept(self, etype, labels):
+        """
+        Determines if entity type is a valid entity type.
+
+        Args:
+            etype: entity type
+            labels: list of entities to accept
+
+        Returns:
+            if etype is accepted
+        """
+
+        return not labels or etype in labels
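For reference, a minimal usage sketch of this pipeline. The model path is illustrative and it assumes `Entity` is re-exported from `txtai.pipeline`, as with the package's other pipelines:

```python
from txtai.pipeline import Entity

# Illustrative token classification model path
entity = Entity("dslim/bert-base-NER")

# List of (entity, entity type, score) tuples
print(entity("Ottawa is the capital of Canada"))

# Flatten to entity text only, keeping entities scoring >= 0.9
print(entity("Ottawa is the capital of Canada", flatten=0.9))
```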
txtai/pipeline/text/labels.py
@@ -0,0 +1,137 @@
+"""
+Labels module
+"""
+
+from ..hfpipeline import HFPipeline
+
+
+class Labels(HFPipeline):
+    """
+    Applies a text classifier to text. Supports zero shot and standard text classification models.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=True, model=None, dynamic=True, **kwargs):
+        super().__init__("zero-shot-classification" if dynamic else "text-classification", path, quantize, gpu, model, **kwargs)
+
+        # Set if labels are dynamic (zero shot) or fixed (standard text classification)
+        self.dynamic = dynamic
+
+    def __call__(self, text, labels=None, multilabel=False, flatten=None, workers=0, **kwargs):
+        """
+        Applies a text classifier to text. Returns a list of (id, score) sorted by highest score,
+        where id is the index in labels. For zero shot classification, a list of labels is required.
+        For text classification models, a list of labels is optional, otherwise all trained labels are returned.
+
+        This method supports text as a string or a list. If the input is a string, the return
+        type is a 1D list of (id, score). If text is a list, a 2D list of (id, score) is
+        returned with a row per string.
+
+        Args:
+            text: text|list
+            labels: list of labels
+            multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None
+            flatten: flatten output to a list of labels if present. Accepts a boolean or float value to only keep scores greater than that number.
+            workers: number of concurrent workers to use for processing data, defaults to 0
+            kwargs: additional keyword args
+
+        Returns:
+            list of (id, score) or list of labels depending on flatten parameter
+        """
+
+        if self.dynamic:
+            # Run zero shot classification pipeline
+            results = self.pipeline(text, labels, multi_label=multilabel, truncation=True, num_workers=workers)
+        else:
+            # Set classification function based on inputs
+            function = "none" if multilabel is None else "sigmoid" if multilabel or len(self.labels()) == 1 else "softmax"
+
+            # Run text classification pipeline
+            results = self.pipeline(text, top_k=None, function_to_apply=function, num_workers=workers, **kwargs)
+
+        # Convert results to a list if necessary
+        if isinstance(text, str):
+            results = [results]
+
+        # Build list of outputs and return
+        outputs = self.outputs(results, labels, flatten)
+        return outputs[0] if isinstance(text, str) else outputs
+
+    def labels(self):
+        """
+        Returns a list of all text classification model labels sorted in index order.
+
+        Returns:
+            list of labels
+        """
+
+        return list(self.pipeline.model.config.id2label.values())
+
+    def outputs(self, results, labels, flatten):
+        """
+        Processes pipeline results and builds outputs.
+
+        Args:
+            results: pipeline results
+            labels: list of labels
+            flatten: flatten output to a list of labels if present. Accepts a boolean or float value to only keep scores greater than that number.
+
+        Returns:
+            list of outputs
+        """
+
+        outputs = []
+        threshold = 0.0 if isinstance(flatten, bool) else flatten
+
+        for result in results:
+            if self.dynamic:
+                if flatten:
+                    result = [label for x, label in enumerate(result["labels"]) if result["scores"][x] >= threshold]
+                    outputs.append(result[:1] if isinstance(flatten, bool) else result)
+                else:
+                    outputs.append([(labels.index(label), result["scores"][x]) for x, label in enumerate(result["labels"])])
+            else:
+                if flatten:
+                    result = [x["label"] for x in result if x["score"] >= threshold and (not labels or x["label"] in labels)]
+                    outputs.append(result[:1] if isinstance(flatten, bool) else result)
+                else:
+                    # Filter results using labels, if provided
+                    outputs.append(self.limit(result, labels))
+
+        return outputs
+
+    def limit(self, result, labels):
+        """
+        Filter result using labels. If labels is None, original result is returned.
+
+        Args:
+            result: results array sorted by score descending
+            labels: list of labels or None
+
+        Returns:
+            filtered results
+        """
+
+        # Get config
+        config = self.pipeline.model.config
+
+        # Resolve label ids for labels
+        result = [(config.label2id.get(x["label"], 0), x["score"]) for x in result]
+
+        if labels:
+            matches = []
+            for label in labels:
+                # Lookup label keys from model config
+                if label.isdigit():
+                    label = int(label)
+                    keys = list(config.id2label.keys())
+                else:
+                    label = label.lower()
+                    keys = [x.lower() for x in config.label2id.keys()]
+
+                # Find and add label match
+                if label in keys:
+                    matches.append(keys.index(label))
+
+            return [(label, score) for label, score in result if label in matches]
+
+        return result
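A usage sketch under the same assumptions (illustrative model path, `Labels` importable from `txtai.pipeline`). With the default `dynamic=True`, the candidate labels are supplied per call and the returned ids index into that list:

```python
from txtai.pipeline import Labels

# Zero-shot classification (dynamic=True is the default)
labels = Labels("facebook/bart-large-mnli")

# List of (id, score) where id indexes into the candidate labels
print(labels("I am happy with this purchase", ["positive", "negative"]))

# Keep only the top label as text
print(labels("I am happy with this purchase", ["positive", "negative"], flatten=True))
```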
txtai/pipeline/text/lateencoder.py
@@ -0,0 +1,103 @@
+"""
+Late encoder module
+"""
+
+import numpy as np
+import torch
+
+from ...models import Models, PoolingFactory
+from ..base import Pipeline
+
+
+class LateEncoder(Pipeline):
+    """
+    Computes similarity between query and list of text using a late interaction model.
+    """
+
+    def __init__(self, path=None, **kwargs):
+        # Get device
+        self.device = Models.device(Models.deviceid(kwargs.get("gpu", True)))
+
+        # Load model
+        self.model = PoolingFactory.create(
+            {
+                "method": kwargs.get("method"),
+                "path": path if path else "colbert-ir/colbertv2.0",
+                "device": self.device,
+                "tokenizer": kwargs.get("tokenizer"),
+                "maxlength": kwargs.get("maxlength"),
+                "modelargs": {**kwargs.get("vectors", {}), **{"muvera": None}},
+            }
+        )
+
+    def __call__(self, query, texts, limit=None):
+        """
+        Computes the similarity between query and list of text. Returns a list of
+        (id, score) sorted by highest score, where id is the index in texts.
+
+        This method supports query as a string or a list. If the input is a string,
+        the return type is a 1D list of (id, score). If query is a list, a 2D list
+        of (id, score) is returned with a row per string.
+
+        Args:
+            query: query text|list
+            texts: list of text
+            limit: maximum comparisons to return, defaults to all
+
+        Returns:
+            list of (id, score)
+        """
+
+        queries = [query] if isinstance(query, str) else query
+
+        # Encode text to vectors
+        queries = self.encode(queries, "query")
+        data = self.encode(texts, "data") if isinstance(texts[0], str) else texts
+
+        # Compute maximum similarity score
+        scores = []
+        for q in queries:
+            scores.extend(self.score(q.unsqueeze(0), data, limit))
+
+        return scores[0] if isinstance(query, str) else scores
+
+    def encode(self, data, category):
+        """
+        Encodes a batch of data using the underlying model.
+
+        Args:
+            data: input data
+            category: encoding category
+
+        Returns:
+            encoded data
+        """
+
+        return torch.from_numpy(self.model.encode(data, category=category)).to(self.device)
+
+    def score(self, queries, data, limit):
+        """
+        Computes the maximum similarity score between query vectors and data vectors.
+
+        Args:
+            queries: query vectors
+            data: data vectors
+            limit: query limit
+
+        Returns:
+            list of (id, score)
+        """
+
+        # Compute bulk dot product using einstein notation
+        scores = torch.einsum("ash,bth->abst", queries, data).max(axis=-1).values.mean(axis=-1)
+        scores = scores.cpu().numpy()
+
+        # Get top n matching indices and scores
+        indices = np.argpartition(-scores, limit if limit and limit < scores.shape[0] else scores.shape[0] - 1)[:, :limit]
+        scores = np.take_along_axis(scores, indices, axis=1)
+
+        results = []
+        for x, index in enumerate(indices):
+            results.append(list(zip(index.tolist(), scores[x].tolist())))
+
+        return results
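The einsum in `score` computes the MaxSim late interaction score: for query token vectors q_1..q_S and document token vectors d_1..d_T, score = (1/S) Σ_s max_t ⟨q_s, d_t⟩. A usage sketch, importing from the module path shown in the file listing above; the model defaults to colbert-ir/colbertv2.0 per the constructor:

```python
from txtai.pipeline.text.lateencoder import LateEncoder

# Defaults to colbert-ir/colbertv2.0 when no path is given
encoder = LateEncoder()

# List of (id, score) where id indexes into texts
print(encoder("capital of Canada", ["Ottawa is the capital of Canada", "Berlin is in Germany"], limit=1))
```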
txtai/pipeline/text/questions.py
@@ -0,0 +1,48 @@
+"""
+Questions module
+"""
+
+from ..hfpipeline import HFPipeline
+
+
+class Questions(HFPipeline):
+    """
+    Runs extractive QA for a series of questions and contexts.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+        super().__init__("question-answering", path, quantize, gpu, model, **kwargs)
+
+    def __call__(self, questions, contexts, workers=0):
+        """
+        Runs an extractive question-answering model against each question-context pair, finding the best answers.
+
+        Args:
+            questions: list of questions
+            contexts: list of contexts to pull answers from
+            workers: number of concurrent workers to use for processing data, defaults to 0
+
+        Returns:
+            list of answers
+        """
+
+        answers = []
+
+        for x, question in enumerate(questions):
+            if question and contexts[x]:
+                # Run the QA pipeline
+                result = self.pipeline(question=question, context=contexts[x], num_workers=workers)
+
+                # Get answer and score
+                answer, score = result["answer"], result["score"]
+
+                # Require score to be at least 0.05
+                if score < 0.05:
+                    answer = None
+
+                # Add answer
+                answers.append(answer)
+            else:
+                answers.append(None)
+
+        return answers
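A usage sketch (illustrative extractive QA model path; assumes `Questions` is re-exported from `txtai.pipeline`):

```python
from txtai.pipeline import Questions

# Illustrative extractive QA model path
questions = Questions("distilbert-base-cased-distilled-squad")

# One answer per question-context pair; None when the best score is below 0.05
print(questions(["What is the capital of Canada?"], ["Ottawa is the capital of Canada"]))
```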
txtai/pipeline/text/reranker.py
@@ -0,0 +1,57 @@
+"""
+Reranker module
+"""
+
+from ..base import Pipeline
+
+
+class Reranker(Pipeline):
+    """
+    Runs embeddings queries and re-ranks them using a similarity pipeline. Note that content must be enabled with the
+    embeddings instance for this to work properly.
+    """
+
+    def __init__(self, embeddings, similarity):
+        """
+        Creates a Reranker pipeline.
+
+        Args:
+            embeddings: embeddings instance (content must be enabled)
+            similarity: similarity instance
+        """
+
+        self.embeddings, self.similarity = embeddings, similarity
+
+    # pylint: disable=W0222
+    def __call__(self, query, limit=3, factor=10, **kwargs):
+        """
+        Runs an embeddings search and re-ranks the results using a Similarity pipeline.
+
+        Args:
+            query: query text|list
+            limit: maximum results
+            factor: factor to multiply limit by for the initial embeddings search
+            kwargs: additional arguments to pass to embeddings search
+
+        Returns:
+            list of query results rescored using a Similarity pipeline
+        """
+
+        queries = [query] if not isinstance(query, list) else query
+
+        # Run searches
+        results = self.embeddings.batchsearch(queries, limit * factor, **kwargs)
+
+        # Re-rank using similarity pipeline
+        ranked = []
+        for x, result in enumerate(results):
+            texts = [row["text"] for row in result]
+
+            # Score results and merge
+            for uid, score in self.similarity(queries[x], texts):
+                result[uid]["score"] = score
+
+            # Sort and take top n sorted results
+            ranked.append(sorted(result, key=lambda row: row["score"], reverse=True)[:limit])
+
+        return ranked[0] if isinstance(query, str) else ranked
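A sketch of the oversample-then-rescore flow. Model paths are illustrative and the imports assume `Reranker` and `Similarity` are re-exported from `txtai.pipeline`; content storage must be enabled on the embeddings instance so each result row carries the "text" field the re-ranker reads:

```python
from txtai import Embeddings
from txtai.pipeline import Reranker, Similarity

# Content storage enabled so search results include a "text" field
embeddings = Embeddings(path="sentence-transformers/all-MiniLM-L6-v2", content=True)
embeddings.index(["Ottawa is the capital of Canada", "Berlin is in Germany"])

# Retrieve limit * factor candidates, then rescore with a cross-encoder similarity pipeline
reranker = Reranker(embeddings, Similarity("cross-encoder/ms-marco-MiniLM-L-6-v2", crossencode=True))
print(reranker("capital of Canada", limit=1))
```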
txtai/pipeline/text/similarity.py
@@ -0,0 +1,83 @@
+"""
+Similarity module
+"""
+
+import numpy as np
+
+from .crossencoder import CrossEncoder
+from .labels import Labels
+from .lateencoder import LateEncoder
+
+
+class Similarity(Labels):
+    """
+    Computes similarity between query and list of text using a transformers model.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=True, model=None, dynamic=True, crossencode=False, lateencode=False, **kwargs):
+        self.crossencoder, self.lateencoder = None, None
+
+        if lateencode:
+            # Load a late interaction encoder if lateencode set to True
+            self.lateencoder = LateEncoder(path=path, gpu=gpu, **kwargs)
+        else:
+            # Use zero-shot classification if dynamic is True and crossencode is False, otherwise use standard text classification
+            super().__init__(path, quantize, gpu, model, False if crossencode else dynamic, **kwargs)
+
+            # Load as a cross-encoder if crossencode set to True
+            self.crossencoder = CrossEncoder(model=self.pipeline) if crossencode else None
+
+    # pylint: disable=W0222
+    def __call__(self, query, texts, multilabel=True, **kwargs):
+        """
+        Computes the similarity between query and list of text. Returns a list of
+        (id, score) sorted by highest score, where id is the index in texts.
+
+        This method supports query as a string or a list. If the input is a string,
+        the return type is a 1D list of (id, score). If query is a list, a 2D list
+        of (id, score) is returned with a row per string.
+
+        Args:
+            query: query text|list
+            texts: list of text
+            multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None
+            kwargs: additional keyword args
+
+        Returns:
+            list of (id, score)
+        """
+
+        if self.crossencoder:
+            # pylint: disable=E1102
+            return self.crossencoder(query, texts, multilabel)
+
+        if self.lateencoder:
+            return self.lateencoder(query, texts)
+
+        # Call Labels pipeline for texts using input query as the candidate label
+        scores = super().__call__(texts, [query] if isinstance(query, str) else query, multilabel, **kwargs)
+
+        # Sort on query index id
+        scores = [[score for _, score in sorted(row)] for row in scores]
+
+        # Transpose axes to get a list of text scores for each query
+        scores = np.array(scores).T.tolist()
+
+        # Build list of (id, score) per query sorted by highest score
+        scores = [sorted(enumerate(row), key=lambda x: x[1], reverse=True) for row in scores]
+
+        return scores[0] if isinstance(query, str) else scores
+
+    def encode(self, data, category):
+        """
+        Encodes a batch of data using the underlying model.
+
+        Args:
+            data: input data
+            category: encoding category
+
+        Returns:
+            encoded data
+        """
+
+        return self.lateencoder.encode(data, category) if self.lateencoder else data
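A usage sketch for the default zero-shot path, where the query is scored against each text as a candidate label (model path illustrative):

```python
from txtai.pipeline import Similarity

# Zero-shot model by default; the query is used as the candidate label
similarity = Similarity("facebook/bart-large-mnli")

# List of (id, score) sorted by score, where id indexes into texts
print(similarity("feel good story", ["Maine man wins $1M from $25 lottery ticket", "Stock market hits record low"]))
```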
txtai/pipeline/text/summary.py
@@ -0,0 +1,98 @@
+"""
+Summary module
+"""
+
+import re
+
+from ..hfpipeline import HFPipeline
+
+
+class Summary(HFPipeline):
+    """
+    Summarizes text.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+        super().__init__("summarization", path, quantize, gpu, model, **kwargs)
+
+    def __call__(self, text, minlength=None, maxlength=None, workers=0):
+        """
+        Runs a summarization model against a block of text.
+
+        This method supports text as a string or a list. If the input is a string, the return
+        type is text. If text is a list, a list of text is returned with a row per block of text.
+
+        Args:
+            text: text|list
+            minlength: minimum length for summary
+            maxlength: maximum length for summary
+            workers: number of concurrent workers to use for processing data, defaults to 0
+
+        Returns:
+            summary text
+        """
+
+        # Validate text length greater than max length
+        check = maxlength if maxlength else self.maxlength()
+
+        # Skip text shorter than max length
+        texts = text if isinstance(text, list) else [text]
+        params = [(x, text if len(text) >= check else None) for x, text in enumerate(texts)]
+
+        # Build keyword arguments
+        kwargs = self.args(minlength, maxlength)
+
+        inputs = [text for _, text in params if text]
+        if inputs:
+            # Run summarization pipeline
+            results = self.pipeline(inputs, num_workers=workers, **kwargs)
+
+            # Pull out summary text
+            results = iter([self.clean(x["summary_text"]) for x in results])
+            results = [next(results) if text else texts[x] for x, text in params]
+        else:
+            # Return original
+            results = texts
+
+        return results[0] if isinstance(text, str) else results
+
+    def clean(self, text):
+        """
+        Applies a series of rules to clean extracted text.
+
+        Args:
+            text: input text
+
+        Returns:
+            clean text
+        """
+
+        text = re.sub(r"\s*\.\s*", ". ", text)
+        text = text.strip()
+
+        return text
+
+    def args(self, minlength, maxlength):
+        """
+        Builds keyword arguments.
+
+        Args:
+            minlength: minimum length for summary
+            maxlength: maximum length for summary
+
+        Returns:
+            keyword arguments
+        """
+
+        kwargs = {"truncation": True}
+        if minlength:
+            kwargs["min_length"] = minlength
+        if maxlength:
+            kwargs["max_length"] = maxlength
+            kwargs["max_new_tokens"] = None
+
+            # Default minlength if not provided or it's bigger than maxlength
+            if "min_length" not in kwargs or kwargs["min_length"] > kwargs["max_length"]:
+                kwargs["min_length"] = kwargs["max_length"]
+
+        return kwargs
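A usage sketch (illustrative summarization model path; assumes `Summary` is re-exported from `txtai.pipeline`). Note the length gate above: input shorter than maxlength characters is returned unchanged rather than summarized:

```python
from txtai.pipeline import Summary

# Illustrative summarization model path
summary = Summary("sshleifer/distilbart-cnn-12-6")

# Long enough to pass the length gate, so it is summarized
text = "Search is the base of many applications. Once data starts to pile up, users want to be able to find it."
print(summary(text, maxlength=30))
```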