mseep_txtai-9.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
@@ -0,0 +1,22 @@
+ """
+ Main module.
+ """
+
+ import sys
+
+ from .base import Console
+
+
+ def main(path=None):
+     """
+     Console execution loop.
+
+     Args:
+         path: model path
+     """
+
+     Console(path).cmdloop()
+
+
+ if __name__ == "__main__":
+     main(sys.argv[1] if len(sys.argv) > 1 else None)
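For reference, a minimal sketch of how this entry point can be invoked (the index path is a placeholder; Console comes from txtai/console/base.py added below):

# Equivalent to: python -m txtai.console /path/to/index
from txtai.console.base import Console

Console("/path/to/index").cmdloop()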
txtai/console/base.py ADDED
@@ -0,0 +1,264 @@
+ """
+ Console module
+ """
+
+ import os
+ import shlex
+
+ from cmd import Cmd
+
+ # Conditional import
+ try:
+     from rich import box
+     from rich.console import Console as RichConsole
+     from rich.table import Table
+
+     RICH = True
+ except ImportError:
+     RICH = False
+
+ from txtai.app import Application
+ from txtai.embeddings import Embeddings
+
+
+ class Console(Cmd):
+     """
+     txtai console.
+     """
+
+     def __init__(self, path=None):
+         """
+         Creates a new command line console.
+
+         Args:
+             path: path to initial configuration, if any
+         """
+
+         super().__init__()
+
+         if not RICH:
+             raise ImportError('Console is not available - install "console" extra to enable')
+
+         self.prompt = ">>> "
+
+         # Rich console
+         self.console = RichConsole()
+
+         # App parameters
+         self.app = None
+         self.path = path
+
+         # Parameters
+         self.vhighlight = None
+         self.vlimit = None
+
+     def preloop(self):
+         """
+         Loads initial configuration.
+         """
+
+         self.console.print("txtai console", style="#03a9f4")
+
+         # Load default path
+         if self.path:
+             self.load(self.path)
+
+     def default(self, line):
+         """
+         Default event loop.
+
+         Args:
+             line: command line
+         """
+
+         # pylint: disable=W0703
+         try:
+             command = line.lower()
+             if command.startswith(".config"):
+                 self.config()
+             elif command.startswith(".highlight"):
+                 self.highlight(command)
+             elif command.startswith(".limit"):
+                 self.limit(command)
+             elif command.startswith(".load"):
+                 command = self.split(line)
+                 self.path = command[1]
+                 self.load(self.path)
+             elif command.startswith(".workflow"):
+                 self.workflow(line)
+             else:
+                 # Search is default action
+                 self.search(line)
+         except Exception:
+             self.console.print_exception()
+
+     def config(self):
+         """
+         Processes .config command.
+         """
+
+         self.console.print(self.app.config)
+
+     def highlight(self, command):
+         """
+         Processes .highlight command.
+
+         Args:
+             command: command line
+         """
+
+         _, action = self.split(command, "#ffff00")
+         self.vhighlight = action
+         self.console.print(f"Set highlight to {self.vhighlight}")
+
+     def limit(self, command):
+         """
+         Processes .limit command.
+
+         Args:
+             command: command line
+         """
+
+         _, action = self.split(command, 10)
+         self.vlimit = int(action)
+         self.console.print(f"Set limit to {self.vlimit}")
+
+     def load(self, path):
+         """
+         Processes .load command.
+
+         Args:
+             path: path to configuration
+         """
+
+         if self.isyaml(path):
+             self.console.print(f"Loading application {path}")
+             self.app = Application(path)
+         else:
+             self.console.print(f"Loading index {path}")
+
+             # Load embeddings index
+             self.app = Embeddings()
+             self.app.load(path)
+
+     def search(self, query):
+         """
+         Runs a search query.
+
+         Args:
+             query: query to run
+         """
+
+         if self.vhighlight:
+             results = self.app.explain(query, limit=self.vlimit)
+         else:
+             results = self.app.search(query, limit=self.vlimit)
+
+         columns, table = {}, Table(box=box.SQUARE, style="#03a9f4")
+
+         # Build column list
+         result = results[0]
+         if isinstance(result, tuple):
+             columns = dict.fromkeys(["id", "score"])
+         else:
+             columns = dict(result)
+
+         # Add columns to table
+         columns = list(x for x in columns if x != "tokens")
+         for column in columns:
+             table.add_column(column)
+
+         # Add rows to table
+         for result in results:
+             if isinstance(result, tuple):
+                 table.add_row(*(self.render(result, None, x) for x in result))
+             else:
+                 table.add_row(*(self.render(result, column, result.get(column)) for column in columns))
+
+         # Print table to console
+         self.console.print(table)
+
+     def workflow(self, command):
+         """
+         Processes .workflow command.
+
+         Args:
+             command: command line
+         """
+
+         command = shlex.split(command)
+         if isinstance(self.app, Application):
+             self.console.print(list(self.app.workflow(command[1], command[2:])))
+
+     def isyaml(self, path):
+         """
+         Checks if file at path is a valid YAML file.
+
+         Args:
+             path: file to check
+
+         Returns:
+             True if file is valid YAML, False otherwise
+         """
+
+         if os.path.exists(path) and os.path.isfile(path):
+             try:
+                 return Application.read(path)
+             # pylint: disable=W0702
+             except:
+                 pass
+
+         return False
+
+     def split(self, command, default=None):
+         """
+         Splits command by whitespace.
+
+         Args:
+             command: command line
+             default: default command action
+
+         Returns:
+             command action
+         """
+
+         values = command.split(" ", 1)
+         return values if len(values) > 1 else (command, default)
+
+     def render(self, result, column, value):
+         """
+         Renders a search result column value.
+
+         Args:
+             result: result row
+             column: column name
+             value: column value
+         """
+
+         if isinstance(value, float):
+             return f"{value:.4f}"
+
+         # Explain highlighting
+         if column == "text" and "tokens" in result:
+             spans = []
+             for token, score in result["tokens"]:
+                 color = None
+                 if score >= 0.02:
+                     color = f"b {self.vhighlight}"
+
+                 spans.append((token, score, color))
+
+             if result["score"] >= 0.05 and not [color for _, _, color in spans if color]:
+                 mscore = max(score for _, score, _ in spans)
+                 spans = [(token, score, f"b {self.vhighlight}" if score == mscore else color) for token, score, color in spans]
+
+             output = ""
+             for token, _, color in spans:
+                 if color:
+                     output += f"[{color}]{token}[/{color}] "
+                 else:
+                     output += f"{token} "
+
+             return output
+
+         return str(value)
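Read together with the __main__ entry point above, the dot-commands handled by default() amount to an interactive session along these lines (a sketch; the index path and query are illustrative and search output is omitted):

>>> .load /path/to/index
Loading index /path/to/index
>>> .limit 5
Set limit to 5
>>> .highlight
Set highlight to #ffff00
>>> feel good story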
txtai/data/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ Data imports
+ """
+
+ from .base import Data
+ from .labels import Labels
+ from .questions import Questions
+ from .sequences import Sequences
+ from .texts import Texts
+ from .tokens import Tokens
txtai/data/base.py ADDED
@@ -0,0 +1,138 @@
+ """
+ Data module
+ """
+
+ from .tokens import Tokens
+
+
+ class Data:
+     """
+     Base data tokenization class.
+     """
+
+     def __init__(self, tokenizer, columns, maxlength):
+         """
+         Creates new base instance for tokenizing data.
+
+         Args:
+             tokenizer: model tokenizer
+             columns: column names
+             maxlength: maximum sequence length
+         """
+
+         self.tokenizer = tokenizer
+         self.columns = columns
+         self.maxlength = maxlength
+
+     def __call__(self, train, validation, workers):
+         """
+         Tokenizes training and validation data and returns processed datasets.
+
+         Args:
+             train: training data
+             validation: validation data
+             workers: number of concurrent tokenizers when processing datasets, only main process used when set to None
+
+         Returns:
+             (train, validation)
+         """
+
+         return (self.prepare(train, self.process, workers), self.prepare(validation, self.process, workers) if validation else None)
+
+     def prepare(self, data, fn, workers):
+         """
+         Prepares and tokenizes data for training.
+
+         Args:
+             data: input data
+             fn: tokenize processing function to apply
+             workers: number of concurrent tokenizers when processing datasets, only main process used when set to None
+
+         Returns:
+             tokens
+         """
+
+         if hasattr(data, "map"):
+             # Hugging Face dataset
+             tokens = data.map(fn, batched=True, num_proc=workers, remove_columns=data.column_names)
+         else:
+             # Re-orient data into columns for efficient batch tokenization
+             columns = {}
+             if hasattr(data, "columns"):
+                 # Polars/pandas DataFrame
+                 for column in data.columns:
+                     columns[column] = list(data[column])
+             else:
+                 # Iterable dicts
+                 for row in data:
+                     for column in row.keys():
+                         if column not in columns:
+                             columns[column] = []
+
+                         columns[column].append(row[column])
+
+             # Process column-oriented data
+             tokens = Tokens(fn(columns))
+
+         return tokens
+
+     def labels(self, data):
+         """
+         Extracts a list of unique labels from data.
+
+         Args:
+             data: input data
+
+         Returns:
+             list of unique labels
+         """
+
+         # Last column is label
+         column = self.columns[-1]
+
+         # Return length of labels if it's an array
+         length = self.length(data[column][0] if hasattr(data, "columns") else data[0][column])
+         if length:
+             return length
+
+         if hasattr(data, "map"):
+             # Hugging Face dataset
+             labels = sorted(data.unique(self.columns[-1]))
+         elif hasattr(data, "columns"):
+             # Polars/pandas DataFrame
+             labels = sorted(data[self.columns[-1]].unique())
+         else:
+             # Iterable dicts
+             labels = sorted({row[self.columns[-1]] for row in data})
+
+         # Labels are single numeric values per entry
+         # - Consider a regression task if at least one label isn't an integer
+         # - Otherwise use number of labels for a classification task
+         return 1 if [x for x in labels if float(x) != int(x)] else len(labels)
+
+     def process(self, data):
+         """
+         Tokenizes batch of input data
+
+         Args:
+             data: input data batch
+
+         Returns:
+             tokenized data
+         """
+
+         return data
+
+     def length(self, value):
+         """
+         Returns the length of value if value has a len function defined. Otherwise,
+         None is returned.
+
+         Args:
+             value: value to check
+
+         Returns:
+             length of value if available, otherwise returns None
+         """
+
+         return len(value) if hasattr(value, "__len__") else None
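A quick sketch of the base class behavior above with iterable dicts (rows are illustrative; the base process() is an identity pass-through, so no tokenizer is needed here):

from txtai.data import Data

rows = [{"text": "good", "label": 0}, {"text": "bad", "label": 1}]

# Base process() returns data unchanged, so tokenizer can be None for this sketch
data = Data(None, ("text", "label"), 128)

# Iterable dicts are re-oriented into columns and wrapped in Tokens
train, validation = data(rows, None, None)

# Two unique integer labels -> classification task with 2 labels
print(data.labels(rows))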
txtai/data/labels.py ADDED
@@ -0,0 +1,42 @@
+ """
+ Labels module
+ """
+
+ from .base import Data
+
+
+ class Labels(Data):
+     """
+     Tokenizes text-classification datasets as input for training text-classification models.
+     """
+
+     def __init__(self, tokenizer, columns, maxlength):
+         """
+         Creates a new instance for tokenizing Labels training data.
+
+         Args:
+             tokenizer: model tokenizer
+             columns: tuple of columns to use for text/label
+             maxlength: maximum sequence length
+         """
+
+         super().__init__(tokenizer, columns, maxlength)
+
+         # Standardize columns
+         if not self.columns:
+             self.columns = ("text", None, "label")
+         elif len(columns) < 3:
+             self.columns = (self.columns[0], None, self.columns[-1])
+
+     def process(self, data):
+         # Column keys
+         text1, text2, label = self.columns
+
+         # Tokenizer inputs can be single string or string pair, depending on task
+         text = (data[text1], data[text2]) if text2 else (data[text1],)
+
+         # Tokenize text and add label
+         inputs = self.tokenizer(*text, max_length=self.maxlength, padding=True, truncation=True)
+         inputs[label] = data[label]
+
+         return inputs
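A possible usage sketch for Labels (the model name and rows are illustrative; a two-column tuple is expanded to (text, None, label) by __init__ above):

from transformers import AutoTokenizer
from txtai.data import Labels

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example model
labels = Labels(tokenizer, ("text", "label"), 128)

rows = [{"text": "great film", "label": 1}, {"text": "dull film", "label": 0}]

# Returns (train tokens, None) when no validation split is passed
train, _ = labels(rows, None, None)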
txtai/data/questions.py ADDED
@@ -0,0 +1,135 @@
+ """
+ Questions module
+ """
+
+ from .base import Data
+
+
+ class Questions(Data):
+     """
+     Tokenizes question-answering datasets as input for training question-answering models.
+     """
+
+     def __init__(self, tokenizer, columns, maxlength, stride):
+         """
+         Creates a new instance for tokenizing Questions training data.
+
+         Args:
+             tokenizer: model tokenizer
+             columns: tuple of columns to use for question/context/answer
+             maxlength: maximum sequence length
+             stride: chunk size for splitting data for QA tasks
+         """
+
+         super().__init__(tokenizer, columns, maxlength)
+
+         if not self.columns:
+             self.columns = ("question", "context", "answers")
+
+         self.question, self.context, self.answer = self.columns
+         self.stride = stride
+         self.rpad = tokenizer.padding_side == "right"
+
+     def process(self, data):
+         # Tokenize data
+         tokenized = self.tokenize(data)
+
+         # Get mapping of overflowing tokens and answer offsets
+         samples = tokenized.pop("overflow_to_sample_mapping")
+         offsets = tokenized.pop("offset_mapping")
+
+         # Start/end positions
+         tokenized["start_positions"] = []
+         tokenized["end_positions"] = []
+
+         for x, offset in enumerate(offsets):
+             # Label NO ANSWER with CLS token
+             inputids = tokenized["input_ids"][x]
+             clstoken = inputids.index(self.tokenizer.cls_token_id)
+
+             # Sequence ids
+             sequences = tokenized.sequence_ids(x)
+
+             # Get and format answer
+             answers = self.answers(data, samples[x])
+
+             # If no answers are given, set cls token as answer.
+             if len(answers["answer_start"]) == 0:
+                 tokenized["start_positions"].append(clstoken)
+                 tokenized["end_positions"].append(clstoken)
+             else:
+                 # Start/end character index of the answer in the text.
+                 startchar = answers["answer_start"][0]
+                 endchar = startchar + len(answers["text"][0])
+
+                 # Start token index of the current span in the text.
+                 start = 0
+                 while sequences[start] != (1 if self.rpad else 0):
+                     start += 1
+
+                 # End token index of the current span in the text.
+                 end = len(inputids) - 1
+                 while sequences[end] != (1 if self.rpad else 0):
+                     end -= 1
+
+                 # Map start character and end character to matching token index
+                 while start < len(offset) and offset[start][0] <= startchar:
+                     start += 1
+                 tokenized["start_positions"].append(start - 1)
+
+                 while offset[end][1] >= endchar:
+                     end -= 1
+                 tokenized["end_positions"].append(end + 1)
+
+         return tokenized
+
+     def tokenize(self, data):
+         """
+         Tokenizes batch of data
+
+         Args:
+             data: input data batch
+
+         Returns:
+             tokenized data
+         """
+
+         # Trim question whitespace
+         data[self.question] = [x.lstrip() for x in data[self.question]]
+
+         # Tokenize records
+         return self.tokenizer(
+             data[self.question if self.rpad else self.context],
+             data[self.context if self.rpad else self.question],
+             truncation="only_second" if self.rpad else "only_first",
+             max_length=self.maxlength,
+             stride=self.stride,
+             return_overflowing_tokens=True,
+             return_offsets_mapping=True,
+             padding=True,
+         )
+
+     def answers(self, data, index):
+         """
+         Gets and formats an answer.
+
+         Args:
+             data: input examples
+             index: answer index to retrieve
+
+         Returns:
+             answers dict
+         """
+
+         # Answer mappings
+         answers = data[self.answer][index]
+         context = data[self.context][index]
+
+         # Handle mapping string answers to dict
+         if not isinstance(answers, dict):
+             if not answers:
+                 answers = {"text": [], "answer_start": []}
+             else:
+                 answers = {"text": [answers], "answer_start": [context.index(answers)]}
+
+         return answers
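A comparable sketch for Questions (the model name and batch are illustrative; a fast tokenizer is required for offset mappings and sequence_ids):

from transformers import AutoTokenizer
from txtai.data import Questions

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # example model
questions = Questions(tokenizer, None, 384, 128)

batch = {
    "question": ["What is txtai?"],
    "context": ["txtai is an all-in-one embeddings database."],
    "answers": ["an all-in-one embeddings database"],
}

# Adds start_positions/end_positions for the answer span
tokenized = questions.process(batch)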
txtai/data/sequences.py ADDED
@@ -0,0 +1,48 @@
+ """
+ Sequences module
+ """
+
+ from .base import Data
+
+
+ class Sequences(Data):
+     """
+     Tokenizes sequence-sequence datasets as input for training sequence-sequence models
+     """
+
+     def __init__(self, tokenizer, columns, maxlength, prefix):
+         """
+         Creates a new instance for tokenizing Sequences training data.
+
+         Args:
+             tokenizer: model tokenizer
+             columns: tuple of columns to use for text/label
+             maxlength: maximum sequence length
+             prefix: source prefix
+         """
+
+         super().__init__(tokenizer, columns, maxlength)
+
+         # Standardize columns
+         if not self.columns:
+             self.columns = ("source", "target")
+
+         # Save source prefix
+         self.prefix = prefix
+
+     def process(self, data):
+         # Column keys
+         source, target = self.columns
+
+         # Tokenize source
+         source = [self.prefix + x if self.prefix else x for x in data[source]]
+         inputs = self.tokenizer(source, max_length=self.maxlength, padding=False, truncation=True)
+
+         # Tokenize target
+         with self.tokenizer.as_target_tokenizer():
+             targets = self.tokenizer(data[target], max_length=self.maxlength, padding=False, truncation=True)
+
+         # Combine inputs
+         inputs["labels"] = targets["input_ids"]
+
+         return inputs
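Finally, a sketch for Sequences (the model name, prefix and rows are illustrative; as_target_tokenizer reflects the transformers API this class calls):

from transformers import AutoTokenizer
from txtai.data import Sequences

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # example model
sequences = Sequences(tokenizer, ("source", "target"), 512, "summarize: ")

rows = [{"source": "txtai is an all-in-one embeddings database for semantic search.", "target": "txtai is an embeddings database"}]

# Source rows get the prefix prepended; target token ids become labels
train, _ = sequences(rows, None, None)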