mseep_txtai-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/text/entity.py
@@ -0,0 +1,140 @@
+ """
+ Entity module
+ """
+
+ # Conditional import
+ try:
+     from gliner import GLiNER
+
+     GLINER = True
+ except ImportError:
+     GLINER = False
+
+ from huggingface_hub.errors import HFValidationError
+ from transformers.utils import cached_file
+
+ from ...models import Models
+ from ..hfpipeline import HFPipeline
+
+
+ class Entity(HFPipeline):
+     """
+     Applies a token classifier to text and extracts entity/label combinations.
+     """
+
+     def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+         # Create a new entity pipeline
+         self.gliner = self.isgliner(path)
+         if self.gliner:
+             if not GLINER:
+                 raise ImportError('GLiNER is not available - install "pipeline" extra to enable')
+
+             # GLiNER entity pipeline
+             self.pipeline = GLiNER.from_pretrained(path)
+             self.pipeline = self.pipeline.to(Models.device(Models.deviceid(gpu)))
+         else:
+             # Standard entity pipeline
+             super().__init__("token-classification", path, quantize, gpu, model, **kwargs)
+
+     def __call__(self, text, labels=None, aggregate="simple", flatten=None, join=False, workers=0):
+         """
+         Applies a token classifier to text and extracts entity/label combinations.
+
+         Args:
+             text: text|list
+             labels: list of entity type labels to accept, defaults to None which accepts all
+             aggregate: method to combine multi token entities - options are "simple" (default), "first", "average" or "max"
+             flatten: flatten output to a list of labels if present. Accepts a boolean or float value to only keep scores greater than that number.
+             join: joins flattened output into a string if True, ignored if flatten not set
+             workers: number of concurrent workers to use for processing data, defaults to 0
+
+         Returns:
+             list of (entity, entity type, score) or list of entities depending on flatten parameter
+         """
+
+         # Run token classification pipeline
+         results = self.execute(text, labels, aggregate, workers)
+
+         # Convert results to a list if necessary
+         if isinstance(text, str):
+             results = [results]
+
+         # Score threshold when flatten is set
+         threshold = 0.0 if isinstance(flatten, bool) else flatten
+
+         # Extract entities if flatten set, otherwise extract (entity, entity type, score) tuples
+         outputs = []
+         for result in results:
+             if flatten:
+                 output = [r["word"] for r in result if self.accept(r["entity_group"], labels) and r["score"] >= threshold]
+                 outputs.append(" ".join(output) if join else output)
+             else:
+                 outputs.append([(r["word"], r["entity_group"], float(r["score"])) for r in result if self.accept(r["entity_group"], labels)])
+
+         return outputs[0] if isinstance(text, str) else outputs
+
+     def isgliner(self, path):
+         """
+         Tests if path is a GLiNER model.
+
+         Args:
+             path: model path
+
+         Returns:
+             True if this is a GLiNER model, False otherwise
+         """
+
+         try:
+             # Test if this model has a gliner_config.json file
+             return cached_file(path_or_repo_id=path, filename="gliner_config.json") is not None
+
+         # Ignore this error - invalid repo or directory
+         except (HFValidationError, OSError):
+             pass
+
+         return False
+
+     def execute(self, text, labels, aggregate, workers):
+         """
+         Runs the entity extraction pipeline.
+
+         Args:
+             text: text|list
+             labels: list of entity type labels to accept, defaults to None which accepts all
+             aggregate: method to combine multi token entities - options are "simple" (default), "first", "average" or "max"
+             workers: number of concurrent workers to use for processing data, defaults to 0
+
+         Returns:
+             list of entities and labels
+         """
+
+         if self.gliner:
+             # Extract entities with GLiNER. Use default CoNLL-2003 labels when not otherwise provided.
+             results = self.pipeline.batch_predict_entities(
+                 text if isinstance(text, list) else [text], labels if labels else ["person", "organization", "location"]
+             )
+
+             # Map results to same format as Transformers token classifier
+             entities = []
+             for result in results:
+                 entities.append([{"word": x["text"], "entity_group": x["label"], "score": x["score"]} for x in result])
+
+             # Return extracted entities
+             return entities if isinstance(text, list) else entities[0]
+
+         # Standard Transformers token classification pipeline
+         return self.pipeline(text, aggregation_strategy=aggregate, num_workers=workers)
+
+     def accept(self, etype, labels):
+         """
+         Determines if the entity type is in the list of accepted labels.
+
+         Args:
+             etype: entity type
+             labels: list of entity type labels to accept
+
+         Returns:
+             True if etype is accepted, False otherwise
+         """
+
+         return not labels or etype in labels
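For context, here is a minimal usage sketch of this pipeline. The model path and input text are illustrative assumptions, not taken from this package diff:

    from txtai.pipeline import Entity

    # Hypothetical model path; any Hugging Face token classification model works here
    entity = Entity("dslim/bert-base-NER")

    # Returns a list of (entity, entity type, score) tuples
    print(entity("Canada's last fully intact ice shelf has suddenly collapsed"))

    # Flatten to entity text only, keeping entities scoring 0.75 or higher
    print(entity("Canada's last fully intact ice shelf has suddenly collapsed", flatten=0.75))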
txtai/pipeline/text/labels.py
@@ -0,0 +1,137 @@
+ """
+ Labels module
+ """
+
+ from ..hfpipeline import HFPipeline
+
+
+ class Labels(HFPipeline):
+     """
+     Applies a text classifier to text. Supports zero-shot and standard text classification models.
+     """
+
+     def __init__(self, path=None, quantize=False, gpu=True, model=None, dynamic=True, **kwargs):
+         super().__init__("zero-shot-classification" if dynamic else "text-classification", path, quantize, gpu, model, **kwargs)
+
+         # Set if labels are dynamic (zero shot) or fixed (standard text classification)
+         self.dynamic = dynamic
+
+     def __call__(self, text, labels=None, multilabel=False, flatten=None, workers=0, **kwargs):
+         """
+         Applies a text classifier to text. Returns a list of (id, score) sorted by highest score,
+         where id is the index in labels. For zero-shot classification, a list of labels is required.
+         For text classification models, a list of labels is optional, otherwise all trained labels are returned.
+
+         This method supports text as a string or a list. If the input is a string, the return
+         type is a 1D list of (id, score). If text is a list, a 2D list of (id, score) is
+         returned with a row per string.
+
+         Args:
+             text: text|list
+             labels: list of labels
+             multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None
+             flatten: flatten output to a list of labels if present. Accepts a boolean or float value to only keep scores greater than that number.
+             workers: number of concurrent workers to use for processing data, defaults to 0
+             kwargs: additional keyword args
+
+         Returns:
+             list of (id, score) or list of labels depending on flatten parameter
+         """
+
+         if self.dynamic:
+             # Run zero-shot classification pipeline
+             results = self.pipeline(text, labels, multi_label=multilabel, truncation=True, num_workers=workers)
+         else:
+             # Set classification function based on inputs
+             function = "none" if multilabel is None else "sigmoid" if multilabel or len(self.labels()) == 1 else "softmax"
+
+             # Run text classification pipeline
+             results = self.pipeline(text, top_k=None, function_to_apply=function, num_workers=workers, **kwargs)
+
+         # Convert results to a list if necessary
+         if isinstance(text, str):
+             results = [results]
+
+         # Build list of outputs and return
+         outputs = self.outputs(results, labels, flatten)
+         return outputs[0] if isinstance(text, str) else outputs
+
+     def labels(self):
+         """
+         Returns a list of all text classification model labels sorted in index order.
+
+         Returns:
+             list of labels
+         """
+
+         return list(self.pipeline.model.config.id2label.values())
+
+     def outputs(self, results, labels, flatten):
+         """
+         Processes pipeline results and builds outputs.
+
+         Args:
+             results: pipeline results
+             labels: list of labels
+             flatten: flatten output to a list of labels if present. Accepts a boolean or float value to only keep scores greater than that number.
+
+         Returns:
+             list of outputs
+         """
+
+         outputs = []
+         threshold = 0.0 if isinstance(flatten, bool) else flatten
+
+         for result in results:
+             if self.dynamic:
+                 if flatten:
+                     result = [label for x, label in enumerate(result["labels"]) if result["scores"][x] >= threshold]
+                     outputs.append(result[:1] if isinstance(flatten, bool) else result)
+                 else:
+                     outputs.append([(labels.index(label), result["scores"][x]) for x, label in enumerate(result["labels"])])
+             else:
+                 if flatten:
+                     result = [x["label"] for x in result if x["score"] >= threshold and (not labels or x["label"] in labels)]
+                     outputs.append(result[:1] if isinstance(flatten, bool) else result)
+                 else:
+                     # Filter results using labels, if provided
+                     outputs.append(self.limit(result, labels))
+
+         return outputs
+
+     def limit(self, result, labels):
+         """
+         Filters result using labels. If labels is None, the original result is returned.
+
+         Args:
+             result: results array sorted by score descending
+             labels: list of labels or None
+
+         Returns:
+             filtered results
+         """
+
+         # Get config
+         config = self.pipeline.model.config
+
+         # Resolve label ids for labels
+         result = [(config.label2id.get(x["label"], 0), x["score"]) for x in result]
+
+         if labels:
+             matches = []
+             for label in labels:
+                 # Lookup label keys from model config
+                 if label.isdigit():
+                     label = int(label)
+                     keys = list(config.id2label.keys())
+                 else:
+                     label = label.lower()
+                     keys = [x.lower() for x in config.label2id.keys()]
+
+                 # Find and add label match
+                 if label in keys:
+                     matches.append(keys.index(label))
+
+             return [(label, score) for label, score in result if label in matches]
+
+         return result
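A short zero-shot usage sketch. The default model resolution comes from HFPipeline; the inputs below are illustrative:

    from txtai.pipeline import Labels

    # dynamic=True (the default) runs zero-shot classification
    labels = Labels()

    # Returns (id, score) pairs sorted by score, where id indexes into the candidate labels
    print(labels("I am happy with this purchase", ["positive", "negative"]))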
txtai/pipeline/text/lateencoder.py
@@ -0,0 +1,103 @@
+ """
+ Late encoder module
+ """
+
+ import numpy as np
+ import torch
+
+ from ...models import Models, PoolingFactory
+ from ..base import Pipeline
+
+
+ class LateEncoder(Pipeline):
+     """
+     Computes similarity between query and list of text using a late interaction model.
+     """
+
+     def __init__(self, path=None, **kwargs):
+         # Get device
+         self.device = Models.device(Models.deviceid(kwargs.get("gpu", True)))
+
+         # Load model
+         self.model = PoolingFactory.create(
+             {
+                 "method": kwargs.get("method"),
+                 "path": path if path else "colbert-ir/colbertv2.0",
+                 "device": self.device,
+                 "tokenizer": kwargs.get("tokenizer"),
+                 "maxlength": kwargs.get("maxlength"),
+                 "modelargs": {**kwargs.get("vectors", {}), **{"muvera": None}},
+             }
+         )
+
+     def __call__(self, query, texts, limit=None):
+         """
+         Computes the similarity between query and list of text. Returns a list of
+         (id, score) sorted by highest score, where id is the index in texts.
+
+         This method supports query as a string or a list. If the input is a string,
+         the return type is a 1D list of (id, score). If query is a list, a 2D list
+         of (id, score) is returned with a row per string.
+
+         Args:
+             query: query text|list
+             texts: list of text
+             limit: maximum comparisons to return, defaults to all
+
+         Returns:
+             list of (id, score)
+         """
+
+         queries = [query] if isinstance(query, str) else query
+
+         # Encode text to vectors
+         queries = self.encode(queries, "query")
+         data = self.encode(texts, "data") if isinstance(texts[0], str) else texts
+
+         # Compute maximum similarity score
+         scores = []
+         for q in queries:
+             scores.extend(self.score(q.unsqueeze(0), data, limit))
+
+         return scores[0] if isinstance(query, str) else scores
+
+     def encode(self, data, category):
+         """
+         Encodes a batch of data using the underlying model.
+
+         Args:
+             data: input data
+             category: encoding category
+
+         Returns:
+             encoded data
+         """
+
+         return torch.from_numpy(self.model.encode(data, category=category)).to(self.device)
+
+     def score(self, queries, data, limit):
+         """
+         Computes the maximum similarity score between query vectors and data vectors.
+
+         Args:
+             queries: query vectors
+             data: data vectors
+             limit: query limit
+
+         Returns:
+             list of (id, score)
+         """
+
+         # Compute bulk dot product using einstein notation
+         scores = torch.einsum("ash,bth->abst", queries, data).max(axis=-1).values.mean(axis=-1)
+         scores = scores.cpu().numpy()
+
+         # Get top n matching indices and scores
+         indices = np.argpartition(-scores, limit if limit and limit < scores.shape[0] else scores.shape[0] - 1)[:, :limit]
+         scores = np.take_along_axis(scores, indices, axis=1)
+
+         results = []
+         for x, index in enumerate(indices):
+             results.append(list(zip(index.tolist(), scores[x].tolist())))
+
+         return results
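The einsum in score implements late interaction (MaxSim) scoring: every query token vector is dotted against every document token vector, the maximum over document tokens is taken per query token, then the maxima are averaged over query tokens. A standalone sketch of that tensor operation, with illustrative shapes:

    import torch

    # a=1 query, s=2 query tokens, b=3 documents, t=4 document tokens, h=8 dimensions
    queries = torch.randn(1, 2, 8)
    data = torch.randn(3, 4, 8)

    # (a, b, s, t) token-level dot products -> max over document tokens -> mean over query tokens
    scores = torch.einsum("ash,bth->abst", queries, data).max(axis=-1).values.mean(axis=-1)
    print(scores.shape)  # torch.Size([1, 3]): one score per (query, document) pair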
txtai/pipeline/text/questions.py
@@ -0,0 +1,48 @@
+ """
+ Questions module
+ """
+
+ from ..hfpipeline import HFPipeline
+
+
+ class Questions(HFPipeline):
+     """
+     Runs extractive QA for a series of questions and contexts.
+     """
+
+     def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+         super().__init__("question-answering", path, quantize, gpu, model, **kwargs)
+
+     def __call__(self, questions, contexts, workers=0):
+         """
+         Runs an extractive question-answering model against each question-context pair, finding the best answers.
+
+         Args:
+             questions: list of questions
+             contexts: list of contexts to pull answers from
+             workers: number of concurrent workers to use for processing data, defaults to 0
+
+         Returns:
+             list of answers
+         """
+
+         answers = []
+
+         for x, question in enumerate(questions):
+             if question and contexts[x]:
+                 # Run the QA pipeline
+                 result = self.pipeline(question=question, context=contexts[x], num_workers=workers)
+
+                 # Get answer and score
+                 answer, score = result["answer"], result["score"]
+
+                 # Require score to be at least 0.05
+                 if score < 0.05:
+                     answer = None
+
+                 # Add answer
+                 answers.append(answer)
+             else:
+                 answers.append(None)
+
+         return answers
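Usage sketch, with an illustrative question-context pair:

    from txtai.pipeline import Questions

    questions = Questions()

    # One answer per question-context pair; answers scoring below 0.05 come back as None
    print(questions(["What team won the game?"], ["The Red Sox defeated the Yankees 5-3"]))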
txtai/pipeline/text/reranker.py
@@ -0,0 +1,57 @@
+ """
+ Reranker module
+ """
+
+ from ..base import Pipeline
+
+
+ class Reranker(Pipeline):
+     """
+     Runs embeddings queries and re-ranks them using a similarity pipeline. Note that content must be enabled with the
+     embeddings instance for this to work properly.
+     """
+
+     def __init__(self, embeddings, similarity):
+         """
+         Creates a Reranker pipeline.
+
+         Args:
+             embeddings: embeddings instance (content must be enabled)
+             similarity: similarity instance
+         """
+
+         self.embeddings, self.similarity = embeddings, similarity
+
+     # pylint: disable=W0222
+     def __call__(self, query, limit=3, factor=10, **kwargs):
+         """
+         Runs an embeddings search and re-ranks the results using a Similarity pipeline.
+
+         Args:
+             query: query text|list
+             limit: maximum results
+             factor: factor to multiply limit by for the initial embeddings search
+             kwargs: additional arguments to pass to embeddings search
+
+         Returns:
+             list of query results rescored using a Similarity pipeline
+         """
+
+         queries = [query] if not isinstance(query, list) else query
+
+         # Run searches
+         results = self.embeddings.batchsearch(queries, limit * factor, **kwargs)
+
+         # Re-rank using similarity pipeline
+         ranked = []
+         for x, result in enumerate(results):
+             texts = [row["text"] for row in result]
+
+             # Score results and merge
+             for uid, score in self.similarity(queries[x], texts):
+                 result[uid]["score"] = score
+
+             # Sort and take top n sorted results
+             ranked.append(sorted(result, key=lambda row: row["score"], reverse=True)[:limit])
+
+         return ranked[0] if isinstance(query, str) else ranked
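A sketch of wiring this together, assuming Embeddings, Reranker and Similarity are exported at the import paths shown; content=True is required so search results include text:

    from txtai import Embeddings
    from txtai.pipeline import Reranker, Similarity

    # Index sample data with content storage enabled
    embeddings = Embeddings(content=True)
    embeddings.index(["US tops 5 million confirmed virus cases", "Make huge profits without work"])

    # Re-rank embeddings search results with a similarity pipeline
    reranker = Reranker(embeddings, Similarity())
    print(reranker("public health story", limit=1))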
txtai/pipeline/text/similarity.py
@@ -0,0 +1,83 @@
+ """
+ Similarity module
+ """
+
+ import numpy as np
+
+ from .crossencoder import CrossEncoder
+ from .labels import Labels
+ from .lateencoder import LateEncoder
+
+
+ class Similarity(Labels):
+     """
+     Computes similarity between query and list of text using a transformers model.
+     """
+
+     def __init__(self, path=None, quantize=False, gpu=True, model=None, dynamic=True, crossencode=False, lateencode=False, **kwargs):
+         self.crossencoder, self.lateencoder = None, None
+
+         if lateencode:
+             # Load a late interaction encoder if lateencode set to True
+             self.lateencoder = LateEncoder(path=path, gpu=gpu, **kwargs)
+         else:
+             # Use zero-shot classification if dynamic is True and crossencode is False, otherwise use standard text classification
+             super().__init__(path, quantize, gpu, model, False if crossencode else dynamic, **kwargs)
+
+             # Load as a cross-encoder if crossencode set to True
+             self.crossencoder = CrossEncoder(model=self.pipeline) if crossencode else None
+
+     # pylint: disable=W0222
+     def __call__(self, query, texts, multilabel=True, **kwargs):
+         """
+         Computes the similarity between query and list of text. Returns a list of
+         (id, score) sorted by highest score, where id is the index in texts.
+
+         This method supports query as a string or a list. If the input is a string,
+         the return type is a 1D list of (id, score). If query is a list, a 2D list
+         of (id, score) is returned with a row per string.
+
+         Args:
+             query: query text|list
+             texts: list of text
+             multilabel: labels are independent if True, scores are normalized to sum to 1 per text item if False, raw scores returned if None
+             kwargs: additional keyword args
+
+         Returns:
+             list of (id, score)
+         """
+
+         if self.crossencoder:
+             # pylint: disable=E1102
+             return self.crossencoder(query, texts, multilabel)
+
+         if self.lateencoder:
+             return self.lateencoder(query, texts)
+
+         # Call Labels pipeline for texts using input query as the candidate label
+         scores = super().__call__(texts, [query] if isinstance(query, str) else query, multilabel, **kwargs)
+
+         # Sort on query index id
+         scores = [[score for _, score in sorted(row)] for row in scores]
+
+         # Transpose axes to get a list of text scores for each query
+         scores = np.array(scores).T.tolist()
+
+         # Build list of (id, score) per query sorted by highest score
+         scores = [sorted(enumerate(row), key=lambda x: x[1], reverse=True) for row in scores]
+
+         return scores[0] if isinstance(query, str) else scores
+
+     def encode(self, data, category):
+         """
+         Encodes a batch of data using the underlying model.
+
+         Args:
+             data: input data
+             category: encoding category
+
+         Returns:
+             encoded data
+         """
+
+         return self.lateencoder.encode(data, category) if self.lateencoder else data
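Usage sketch for the default zero-shot path; the inputs are illustrative:

    from txtai.pipeline import Similarity

    similarity = Similarity()

    # Returns (id, score) pairs sorted by score, where id is the index into texts
    print(similarity("feel good story", [
        "Maine man wins $1M from $25 lottery ticket",
        "Don't sacrifice slower friends in a bear attack",
    ]))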
txtai/pipeline/text/summary.py
@@ -0,0 +1,98 @@
+ """
+ Summary module
+ """
+
+ import re
+
+ from ..hfpipeline import HFPipeline
+
+
+ class Summary(HFPipeline):
+     """
+     Summarizes text.
+     """
+
+     def __init__(self, path=None, quantize=False, gpu=True, model=None, **kwargs):
+         super().__init__("summarization", path, quantize, gpu, model, **kwargs)
+
+     def __call__(self, text, minlength=None, maxlength=None, workers=0):
+         """
+         Runs a summarization model against a block of text.
+
+         This method supports text as a string or a list. If the input is a string, the return
+         type is text. If text is a list, a list of text is returned with a row per block of text.
+
+         Args:
+             text: text|list
+             minlength: minimum length for summary
+             maxlength: maximum length for summary
+             workers: number of concurrent workers to use for processing data, defaults to 0
+
+         Returns:
+             summary text
+         """
+
+         # Determine the length threshold for summarization
+         check = maxlength if maxlength else self.maxlength()
+
+         # Skip text shorter than the threshold
+         texts = text if isinstance(text, list) else [text]
+         params = [(x, text if len(text) >= check else None) for x, text in enumerate(texts)]
+
+         # Build keyword arguments
+         kwargs = self.args(minlength, maxlength)
+
+         inputs = [text for _, text in params if text]
+         if inputs:
+             # Run summarization pipeline
+             results = self.pipeline(inputs, num_workers=workers, **kwargs)
+
+             # Pull out summary text
+             results = iter([self.clean(x["summary_text"]) for x in results])
+             results = [next(results) if text else texts[x] for x, text in params]
+         else:
+             # Return original
+             results = texts
+
+         return results[0] if isinstance(text, str) else results
+
+     def clean(self, text):
+         """
+         Applies a series of rules to clean extracted text.
+
+         Args:
+             text: input text
+
+         Returns:
+             clean text
+         """
+
+         text = re.sub(r"\s*\.\s*", ". ", text)
+         text = text.strip()
+
+         return text
+
+     def args(self, minlength, maxlength):
+         """
+         Builds keyword arguments.
+
+         Args:
+             minlength: minimum length for summary
+             maxlength: maximum length for summary
+
+         Returns:
+             keyword arguments
+         """
+
+         kwargs = {"truncation": True}
+         if minlength:
+             kwargs["min_length"] = minlength
+         if maxlength:
+             kwargs["max_length"] = maxlength
+             kwargs["max_new_tokens"] = None
+
+         # Default minlength when not provided or larger than maxlength
+         if "max_length" in kwargs and ("min_length" not in kwargs or kwargs["min_length"] > kwargs["max_length"]):
+             kwargs["min_length"] = kwargs["max_length"]
+
+         return kwargs
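Usage sketch; the input paragraph is illustrative, and text shorter than the effective maxlength is returned unchanged:

    from txtai.pipeline import Summary

    summary = Summary()

    text = (
        "Search is the base of many applications. Once data starts to pile up, users want to be able to find it. "
        "It's the foundation of the internet and an ever-growing challenge that is never solved or done."
    )

    print(summary(text, maxlength=30))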