mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/console/__main__.py
ADDED
@@ -0,0 +1,22 @@
"""
Main module.
"""

import sys

from .base import Console


def main(path=None):
    """
    Console execution loop.

    Args:
        path: model path
    """

    Console(path).cmdloop()


if __name__ == "__main__":
    main(sys.argv[1] if len(sys.argv) > 1 else None)
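Because main falls back to sys.argv, the console can be launched as a module or embedded in a script; a minimal sketch (the index path is a placeholder):

# From a shell: python -m txtai.console /path/to/index

# Or launch programmatically
from txtai.console.base import Console

Console("/path/to/index").cmdloop()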
txtai/console/base.py
ADDED
@@ -0,0 +1,264 @@
"""
Console module
"""

import os
import shlex

from cmd import Cmd

# Conditional import
try:
    from rich import box
    from rich.console import Console as RichConsole
    from rich.table import Table

    RICH = True
except ImportError:
    RICH = False

from txtai.app import Application
from txtai.embeddings import Embeddings


class Console(Cmd):
    """
    txtai console.
    """

    def __init__(self, path=None):
        """
        Creates a new command line console.

        Args:
            path: path to initial configuration, if any
        """

        super().__init__()

        if not RICH:
            raise ImportError('Console is not available - install "console" extra to enable')

        self.prompt = ">>> "

        # Rich console
        self.console = RichConsole()

        # App parameters
        self.app = None
        self.path = path

        # Parameters
        self.vhighlight = None
        self.vlimit = None

    def preloop(self):
        """
        Loads initial configuration.
        """

        self.console.print("txtai console", style="#03a9f4")

        # Load default path
        if self.path:
            self.load(self.path)

    def default(self, line):
        """
        Default event loop.

        Args:
            line: command line
        """

        # pylint: disable=W0703
        try:
            command = line.lower()
            if command.startswith(".config"):
                self.config()
            elif command.startswith(".highlight"):
                self.highlight(command)
            elif command.startswith(".limit"):
                self.limit(command)
            elif command.startswith(".load"):
                command = self.split(line)
                self.path = command[1]
                self.load(self.path)
            elif command.startswith(".workflow"):
                self.workflow(line)
            else:
                # Search is default action
                self.search(line)
        except Exception:
            self.console.print_exception()

    def config(self):
        """
        Processes .config command.
        """

        self.console.print(self.app.config)

    def highlight(self, command):
        """
        Processes .highlight command.

        Args:
            command: command line
        """

        _, action = self.split(command, "#ffff00")
        self.vhighlight = action
        self.console.print(f"Set highlight to {self.vhighlight}")

    def limit(self, command):
        """
        Processes .limit command.

        Args:
            command: command line
        """

        _, action = self.split(command, 10)
        self.vlimit = int(action)
        self.console.print(f"Set limit to {self.vlimit}")

    def load(self, path):
        """
        Processes .load command.

        Args:
            path: path to configuration
        """

        if self.isyaml(path):
            self.console.print(f"Loading application {path}")
            self.app = Application(path)
        else:
            self.console.print(f"Loading index {path}")

            # Load embeddings index
            self.app = Embeddings()
            self.app.load(path)

    def search(self, query):
        """
        Runs a search query.

        Args:
            query: query to run
        """

        if self.vhighlight:
            results = self.app.explain(query, limit=self.vlimit)
        else:
            results = self.app.search(query, limit=self.vlimit)

        columns, table = {}, Table(box=box.SQUARE, style="#03a9f4")

        # Build column list
        result = results[0]
        if isinstance(result, tuple):
            columns = dict.fromkeys(["id", "score"])
        else:
            columns = dict(result)

        # Add columns to table
        columns = list(x for x in columns if x != "tokens")
        for column in columns:
            table.add_column(column)

        # Add rows to table
        for result in results:
            if isinstance(result, tuple):
                table.add_row(*(self.render(result, None, x) for x in result))
            else:
                table.add_row(*(self.render(result, column, result.get(column)) for column in columns))

        # Print table to console
        self.console.print(table)

    def workflow(self, command):
        """
        Processes .workflow command.

        Args:
            command: command line
        """

        command = shlex.split(command)
        if isinstance(self.app, Application):
            self.console.print(list(self.app.workflow(command[1], command[2:])))

    def isyaml(self, path):
        """
        Checks if file at path is a valid YAML file.

        Args:
            path: file to check

        Returns:
            True if file is valid YAML, False otherwise
        """

        if os.path.exists(path) and os.path.isfile(path):
            try:
                return Application.read(path)
            # pylint: disable=W0702
            except:
                pass

        return False

    def split(self, command, default=None):
        """
        Splits command by whitespace.

        Args:
            command: command line
            default: default command action

        Returns:
            command action
        """

        values = command.split(" ", 1)
        return values if len(values) > 1 else (command, default)

    def render(self, result, column, value):
        """
        Renders a search result column value.

        Args:
            result: result row
            column: column name
            value: column value
        """

        if isinstance(value, float):
            return f"{value:.4f}"

        # Explain highlighting
        if column == "text" and "tokens" in result:
            spans = []
            for token, score in result["tokens"]:
                color = None
                if score >= 0.02:
                    color = f"b {self.vhighlight}"

                spans.append((token, score, color))

            if result["score"] >= 0.05 and not [color for _, _, color in spans if color]:
                mscore = max(score for _, score, _ in spans)
                spans = [(token, score, f"b {self.vhighlight}" if score == mscore else color) for token, score, color in spans]

            output = ""
            for token, _, color in spans:
                if color:
                    output += f"[{color}]{token}[/{color}] "
                else:
                    output += f"{token} "

            return output

        return str(value)
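Everything typed at the >>> prompt routes through default: dot-prefixed lines dispatch to the matching handler, anything else runs as a search. A sample session sketch, assuming an index exists at the path shown:

>>> .load /path/to/index     # YAML config loads an Application, anything else an Embeddings index
>>> .limit 5                 # .limit with no argument defaults to 10
>>> .highlight               # enables explain mode, default highlight color is #ffff00
>>> feel good story          # any other input runs as a search query
>>> .config                  # prints the loaded configuration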
txtai/data/__init__.py
ADDED
txtai/data/base.py
ADDED
@@ -0,0 +1,138 @@
"""
Data module
"""

from .tokens import Tokens


class Data:
    """
    Base data tokenization class.
    """

    def __init__(self, tokenizer, columns, maxlength):
        """
        Creates new base instance for tokenizing data.

        Args:
            tokenizer: model tokenizer
            columns: column names
            maxlength: maximum sequence length
        """

        self.tokenizer = tokenizer
        self.columns = columns
        self.maxlength = maxlength

    def __call__(self, train, validation, workers):
        """
        Tokenizes training and validation data and returns processed datasets.

        Args:
            train: training data
            validation: validation data
            workers: number of concurrent tokenizers when processing datasets, only main process used when set to None

        Returns:
            (train, validation)
        """

        return (self.prepare(train, self.process, workers), self.prepare(validation, self.process, workers) if validation else None)

    def prepare(self, data, fn, workers):
        """
        Prepares and tokenizes data for training.

        Args:
            data: input data
            fn: tokenize processing function to apply
            workers: number of concurrent tokenizers when processing datasets, only main process used when set to None

        Returns:
            tokens
        """

        if hasattr(data, "map"):
            # Hugging Face dataset
            tokens = data.map(fn, batched=True, num_proc=workers, remove_columns=data.column_names)
        else:
            # Re-orient data into columns for efficient batch tokenization
            columns = {}
            if hasattr(data, "columns"):
                # Polars/pandas DataFrame
                for column in data.columns:
                    columns[column] = list(data[column])
            else:
                # Iterable dicts
                for row in data:
                    for column in row.keys():
                        if column not in columns:
                            columns[column] = []

                        columns[column].append(row[column])

            # Process column-oriented data
            tokens = Tokens(fn(columns))

        return tokens

    def labels(self, data):
        """
        Extracts a list of unique labels from data.

        Args:
            data: input data

        Returns:
            list of unique labels
        """

        # Last column is label
        column = self.columns[-1]

        # Return length of labels if it's an array
        length = self.length(data[column][0] if hasattr(data, "columns") else data[0][column])
        if length:
            return length

        if hasattr(data, "map"):
            # Hugging Face dataset
            labels = sorted(data.unique(self.columns[-1]))
        elif hasattr(data, "columns"):
            # Polars/pandas DataFrame
            labels = sorted(data[self.columns[-1]].unique())
        else:
            # Iterable dicts
            labels = sorted({row[self.columns[-1]] for row in data})

        # Labels are single numeric values per entry
        # - Consider a regression task if at least one label isn't an integer
        # - Otherwise use number of labels for a classification task
        return 1 if [x for x in labels if float(x) != int(x)] else len(labels)

    def process(self, data):
        """
        Tokenizes batch of input data

        Args:
            data: input data batch

        Returns:
            tokenized data
        """

        return data

    def length(self, value):
        """
        Returns the length of value if value has a len function defined. Otherwise,
        None is returned.

        Args:
            value: value to check

        Returns:
            length of value if available, otherwise returns None
        """

        return len(value) if hasattr(value, "__len__") else None
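The labels method infers the size of the training head from the label column: array labels return their length, any non-integer label signals regression (1 output) and integer labels return the unique class count. A quick sketch against plain iterable-of-dicts data (no tokenizer is needed on this path):

from txtai.data.base import Data

data = Data(tokenizer=None, columns=("text", "label"), maxlength=None)

# Two unique integer labels -> classification with 2 classes
print(data.labels([{"text": "a", "label": 0}, {"text": "b", "label": 1}]))  # 2

# Non-integer label -> regression, single output
print(data.labels([{"text": "a", "label": 0.5}]))  # 1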
txtai/data/labels.py
ADDED
@@ -0,0 +1,42 @@
"""
Labels module
"""

from .base import Data


class Labels(Data):
    """
    Tokenizes text-classification datasets as input for training text-classification models.
    """

    def __init__(self, tokenizer, columns, maxlength):
        """
        Creates a new instance for tokenizing Labels training data.

        Args:
            tokenizer: model tokenizer
            columns: tuple of columns to use for text/label
            maxlength: maximum sequence length
        """

        super().__init__(tokenizer, columns, maxlength)

        # Standardize columns
        if not self.columns:
            self.columns = ("text", None, "label")
        elif len(columns) < 3:
            self.columns = (self.columns[0], None, self.columns[-1])

    def process(self, data):
        # Column keys
        text1, text2, label = self.columns

        # Tokenizer inputs can be single string or string pair, depending on task
        text = (data[text1], data[text2]) if text2 else (data[text1],)

        # Tokenize text and add label
        inputs = self.tokenizer(*text, max_length=self.maxlength, padding=True, truncation=True)
        inputs[label] = data[label]

        return inputs
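Labels plugs into the base __call__/prepare flow; a minimal end-to-end sketch, assuming a Hugging Face tokenizer is available (the model name is illustrative):

from transformers import AutoTokenizer

from txtai.data.labels import Labels

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# columns=None defaults to ("text", None, "label")
labels = Labels(tokenizer, None, maxlength=128)

train = [{"text": "great product", "label": 1}, {"text": "poor service", "label": 0}]
tokens, validation = labels(train, None, workers=None)  # validation is None here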
txtai/data/questions.py
ADDED
@@ -0,0 +1,135 @@
"""
Questions module
"""

from .base import Data


class Questions(Data):
    """
    Tokenizes question-answering datasets as input for training question-answering models.
    """

    def __init__(self, tokenizer, columns, maxlength, stride):
        """
        Creates a new instance for tokenizing Questions training data.

        Args:
            tokenizer: model tokenizer
            columns: tuple of columns to use for question/context/answer
            maxlength: maximum sequence length
            stride: chunk size for splitting data for QA tasks
        """

        super().__init__(tokenizer, columns, maxlength)

        if not self.columns:
            self.columns = ("question", "context", "answers")

        self.question, self.context, self.answer = self.columns
        self.stride = stride
        self.rpad = tokenizer.padding_side == "right"

    def process(self, data):
        # Tokenize data
        tokenized = self.tokenize(data)

        # Get mapping of overflowing tokens and answer offsets
        samples = tokenized.pop("overflow_to_sample_mapping")
        offsets = tokenized.pop("offset_mapping")

        # Start/end positions
        tokenized["start_positions"] = []
        tokenized["end_positions"] = []

        for x, offset in enumerate(offsets):
            # Label NO ANSWER with CLS token
            inputids = tokenized["input_ids"][x]
            clstoken = inputids.index(self.tokenizer.cls_token_id)

            # Sequence ids
            sequences = tokenized.sequence_ids(x)

            # Get and format answer
            answers = self.answers(data, samples[x])

            # If no answers are given, set cls token as answer.
            if len(answers["answer_start"]) == 0:
                tokenized["start_positions"].append(clstoken)
                tokenized["end_positions"].append(clstoken)
            else:
                # Start/end character index of the answer in the text.
                startchar = answers["answer_start"][0]
                endchar = startchar + len(answers["text"][0])

                # Start token index of the current span in the text.
                start = 0
                while sequences[start] != (1 if self.rpad else 0):
                    start += 1

                # End token index of the current span in the text.
                end = len(inputids) - 1
                while sequences[end] != (1 if self.rpad else 0):
                    end -= 1

                # Map start character and end character to matching token index
                while start < len(offset) and offset[start][0] <= startchar:
                    start += 1
                tokenized["start_positions"].append(start - 1)

                while offset[end][1] >= endchar:
                    end -= 1
                tokenized["end_positions"].append(end + 1)

        return tokenized

    def tokenize(self, data):
        """
        Tokenizes batch of data

        Args:
            data: input data batch

        Returns:
            tokenized data
        """

        # Trim question whitespace
        data[self.question] = [x.lstrip() for x in data[self.question]]

        # Tokenize records
        return self.tokenizer(
            data[self.question if self.rpad else self.context],
            data[self.context if self.rpad else self.question],
            truncation="only_second" if self.rpad else "only_first",
            max_length=self.maxlength,
            stride=self.stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding=True,
        )

    def answers(self, data, index):
        """
        Gets and formats an answer.

        Args:
            data: input examples
            index: answer index to retrieve

        Returns:
            answers dict
        """

        # Answer mappings
        answers = data[self.answer][index]
        context = data[self.context][index]

        # Handle mapping string answers to dict
        if not isinstance(answers, dict):
            if not answers:
                answers = {"text": [], "answer_start": []}
            else:
                answers = {"text": [answers], "answer_start": [context.index(answers)]}

        return answers
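Questions.tokenize depends on a fast (Rust-backed) tokenizer, since only fast tokenizers return offset_mapping and overflow_to_sample_mapping. A minimal sketch for one SQuAD-style record; note that plain string answers are mapped to character offsets by the answers method:

from transformers import AutoTokenizer

from txtai.data.questions import Questions

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
questions = Questions(tokenizer, None, maxlength=384, stride=128)

batch = {
    "question": ["  Who maintains txtai?"],
    "context": ["txtai is an open-source project maintained by NeuML."],
    "answers": ["NeuML"],
}

tokenized = questions.process(batch)
print(tokenized["start_positions"], tokenized["end_positions"])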
txtai/data/sequences.py
ADDED
@@ -0,0 +1,48 @@
"""
Sequences module
"""

from .base import Data


class Sequences(Data):
    """
    Tokenizes sequence-sequence datasets as input for training sequence-sequence models
    """

    def __init__(self, tokenizer, columns, maxlength, prefix):
        """
        Creates a new instance for tokenizing Sequences training data.

        Args:
            tokenizer: model tokenizer
            columns: tuple of columns to use for text/label
            maxlength: maximum sequence length
            prefix: source prefix
        """

        super().__init__(tokenizer, columns, maxlength)

        # Standardize columns
        if not self.columns:
            self.columns = ("source", "target")

        # Save source prefix
        self.prefix = prefix

    def process(self, data):
        # Column keys
        source, target = self.columns

        # Tokenize source
        source = [self.prefix + x if self.prefix else x for x in data[source]]
        inputs = self.tokenizer(source, max_length=self.maxlength, padding=False, truncation=True)

        # Tokenize target
        with self.tokenizer.as_target_tokenizer():
            targets = self.tokenizer(data[target], max_length=self.maxlength, padding=False, truncation=True)

        # Combine inputs
        inputs["labels"] = targets["input_ids"]

        return inputs
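A minimal sketch for a T5-style model, where prefix prepends a task instruction to every source string (the model name is illustrative; as_target_tokenizer is the legacy transformers API this code targets):

from transformers import AutoTokenizer

from txtai.data.sequences import Sequences

tokenizer = AutoTokenizer.from_pretrained("t5-small")
sequences = Sequences(tokenizer, None, maxlength=512, prefix="summarize: ")

batch = {"source": ["txtai is an all-in-one embeddings database for semantic search."], "target": ["txtai overview"]}
inputs = sequences.process(batch)  # input_ids and attention_mask plus labels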