mseep-txtai 9.1.1 (mseep_txtai-9.1.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251)
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/train/hftrainer.py ADDED
@@ -0,0 +1,398 @@
+"""
+Hugging Face Transformers trainer wrapper module
+"""
+
+import os
+import sys
+
+import torch
+
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
+    AutoModelForQuestionAnswering,
+    AutoModelForPreTraining,
+    AutoModelForSeq2SeqLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
+from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, Trainer, set_seed
+from transformers import TrainingArguments as HFTrainingArguments
+
+# Conditional import
+try:
+    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
+
+    # pylint: disable=C0412
+    from transformers import BitsAndBytesConfig
+
+    PEFT = True
+except ImportError:
+    PEFT = False
+
+from ...data import Labels, Questions, Sequences, Texts
+from ...models import Models, TokenDetection
+from ..tensors import Tensors
+
+
+class HFTrainer(Tensors):
+    """
+    Trains a new Hugging Face Transformer model using the Trainer framework.
+    """
+
+    # pylint: disable=R0913
+    def __call__(
+        self,
+        base,
+        train,
+        validation=None,
+        columns=None,
+        maxlength=None,
+        stride=128,
+        task="text-classification",
+        prefix=None,
+        metrics=None,
+        tokenizers=None,
+        checkpoint=None,
+        quantize=None,
+        lora=None,
+        **args
+    ):
+        """
+        Builds a new model using arguments.
+
+        Args:
+            base: path to base model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple
+            train: training data
+            validation: validation data
+            columns: tuple of columns to use for text/label, defaults to (text, None, label)
+            maxlength: maximum sequence length, defaults to tokenizer.model_max_length
+            stride: chunk size for splitting data for QA tasks
+            task: optional model task or category, determines the model type, defaults to "text-classification"
+            prefix: optional source prefix
+            metrics: optional function that computes and returns a dict of evaluation metrics
+            tokenizers: optional number of concurrent tokenizers, defaults to None
+            checkpoint: optional resume from checkpoint flag or path to checkpoint directory, defaults to None
+            quantize: quantization configuration to pass to base model
+            lora: lora configuration to pass to PEFT model
+            args: training arguments
+
+        Returns:
+            (model, tokenizer)
+        """
+
+        # Quantization / LoRA support
+        if (quantize or lora) and not PEFT:
+            raise ImportError('PEFT is not available - install "pipeline" extra to enable')
+
+        # Parse TrainingArguments
+        args = self.parse(args)
+
+        # Set seed for model reproducibility
+        set_seed(args.seed)
+
+        # Load model configuration, tokenizer and max sequence length
+        config, tokenizer, maxlength = self.load(base, maxlength)
+
+        # Default tokenizer pad token if it's not set
+        tokenizer.pad_token = tokenizer.pad_token if tokenizer.pad_token is not None else tokenizer.eos_token
+
+        # Prepare parameters
+        process, collator, labels = self.prepare(task, train, tokenizer, columns, maxlength, stride, prefix, args)
+
+        # Tokenize training and validation data
+        train, validation = process(train, validation, os.cpu_count() if tokenizers and isinstance(tokenizers, bool) else tokenizers)
+
+        # Create model to train
+        model = self.model(task, base, config, labels, tokenizer, quantize)
+
+        # Default config pad token if it's not set
+        model.config.pad_token_id = model.config.pad_token_id if model.config.pad_token_id is not None else model.config.eos_token_id
+
+        # Load as PEFT model, if necessary
+        model = self.peft(task, lora, model)
+
+        # Add model to collator
+        if collator:
+            collator.model = model
+
+        # Build trainer
+        trainer = Trainer(
+            model=model,
+            tokenizer=tokenizer,
+            data_collator=collator,
+            args=args,
+            train_dataset=train,
+            eval_dataset=validation if validation else None,
+            compute_metrics=metrics,
+        )
+
+        # Run training
+        trainer.train(resume_from_checkpoint=checkpoint)
+
+        # Run evaluation
+        if validation:
+            trainer.evaluate()
+
+        # Save model outputs
+        if args.should_save:
+            trainer.save_model()
+            trainer.save_state()
+
+        # Put model in eval mode to disable weight updates and return (model, tokenizer)
+        return (model.eval(), tokenizer)
+
+    def parse(self, updates):
+        """
+        Parses and merges custom arguments with defaults.
+
+        Args:
+            updates: custom arguments
+
+        Returns:
+            TrainingArguments
+        """
+
+        # Default training arguments
+        args = {"output_dir": "", "save_strategy": "no", "report_to": "none", "log_level": "warning", "use_cpu": not Models.hasaccelerator()}
+
+        # Apply custom arguments
+        args.update(updates)
+
+        return TrainingArguments(**args)
+
+    def load(self, base, maxlength):
+        """
+        Loads the base config and tokenizer.
+
+        Args:
+            base: base model - supports a file path or (model, tokenizer) tuple
+            maxlength: maximum sequence length
+
+        Returns:
+            (config, tokenizer, maxlength)
+        """
+
+        if isinstance(base, (list, tuple)):
+            # Unpack existing config and tokenizer
+            model, tokenizer = base
+            config = model.config
+        else:
+            # Load config
+            config = AutoConfig.from_pretrained(base)
+
+            # Load tokenizer
+            tokenizer = AutoTokenizer.from_pretrained(base)
+
+        # Detect unbounded tokenizer
+        Models.checklength(config, tokenizer)
+
+        # Derive max sequence length
+        maxlength = min(maxlength if maxlength else sys.maxsize, tokenizer.model_max_length)
+
+        return (config, tokenizer, maxlength)
+
+    def prepare(self, task, train, tokenizer, columns, maxlength, stride, prefix, args):
+        """
+        Prepares data for model training.
+
+        Args:
+            task: optional model task or category, determines the model type, defaults to "text-classification"
+            train: training data
+            tokenizer: model tokenizer
+            columns: tuple of columns to use for text/label, defaults to (text, None, label)
+            maxlength: maximum sequence length, defaults to tokenizer.model_max_length
+            stride: chunk size for splitting data for QA tasks
+            prefix: optional source prefix
+            args: training arguments
+        """
+
+        process, collator, labels = None, None, None
+
+        if task == "language-generation":
+            process = Texts(tokenizer, columns, maxlength)
+            collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8 if args.fp16 else None)
+        elif task in ("language-modeling", "token-detection"):
+            process = Texts(tokenizer, columns, maxlength)
+            collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
+        elif task == "question-answering":
+            process = Questions(tokenizer, columns, maxlength, stride)
+        elif task == "sequence-sequence":
+            process = Sequences(tokenizer, columns, maxlength, prefix)
+            collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
+        else:
+            process = Labels(tokenizer, columns, maxlength)
+            labels = process.labels(train)
+
+        return process, collator, labels
+
+    def model(self, task, base, config, labels, tokenizer, quantize):
+        """
+        Loads the base model to train.
+
+        Args:
+            task: optional model task or category, determines the model type, defaults to "text-classification"
+            base: base model - supports a file path or (model, tokenizer) tuple
+            config: model configuration
+            labels: number of labels
+            tokenizer: model tokenizer
+            quantize: quantization config
+
+        Returns:
+            model
+        """
+
+        if labels is not None:
+            # Add number of labels to config
+            config.update({"num_labels": labels})
+
+        # Format quantization configuration
+        quantization = self.quantization(quantize)
+
+        # Clear quantization configuration if GPU is not available
+        quantization = quantization if torch.cuda.is_available() else None
+
+        # pylint: disable=E1120
+        # Unpack existing model or create new model from config
+        if isinstance(base, (list, tuple)) and not isinstance(base[0], str):
+            return base[0]
+        if task == "language-generation":
+            return AutoModelForCausalLM.from_pretrained(base, config=config, quantization_config=quantization)
+        if task == "language-modeling":
+            return AutoModelForMaskedLM.from_pretrained(base, config=config, quantization_config=quantization)
+        if task == "question-answering":
+            return AutoModelForQuestionAnswering.from_pretrained(base, config=config, quantization_config=quantization)
+        if task == "sequence-sequence":
+            return AutoModelForSeq2SeqLM.from_pretrained(base, config=config, quantization_config=quantization)
+        if task == "token-detection":
+            return TokenDetection(
+                AutoModelForMaskedLM.from_pretrained(base, config=config, quantization_config=quantization),
+                AutoModelForPreTraining.from_pretrained(base, config=config, quantization_config=quantization),
+                tokenizer,
+            )
+
+        # Default task
+        return AutoModelForSequenceClassification.from_pretrained(base, config=config, quantization_config=quantization)
+
+    def quantization(self, quantize):
+        """
+        Formats and returns quantization configuration.
+
+        Args:
+            quantize: input quantization configuration
+
+        Returns:
+            formatted quantization configuration
+        """
+
+        if quantize:
+            # Default quantization settings when set to True
+            if isinstance(quantize, bool):
+                quantize = {
+                    "load_in_4bit": True,
+                    "bnb_4bit_use_double_quant": True,
+                    "bnb_4bit_quant_type": "nf4",
+                    "bnb_4bit_compute_dtype": "bfloat16",
+                }
+
+            # Load dictionary configuration
+            if isinstance(quantize, dict):
+                quantize = BitsAndBytesConfig(**quantize)
+
+        return quantize if quantize else None
+
+    def peft(self, task, lora, model):
+        """
+        Wraps the input model as a PEFT model if lora configuration is set.
+
+        Args:
+            task: optional model task or category, determines the model type, defaults to "text-classification"
+            lora: lora configuration
+            model: transformers model
+
+        Returns:
+            wrapped model if lora configuration set, otherwise input model is returned
+        """
+
+        if lora:
+            # Format LoRA configuration
+            config = self.lora(task, lora)
+
+            # Wrap as PeftModel
+            model = prepare_model_for_kbit_training(model)
+            model = get_peft_model(model, config)
+            model.print_trainable_parameters()
+
+        return model
+
+    def lora(self, task, lora):
+        """
+        Formats and returns LoRA configuration.
+
+        Args:
+            task: optional model task or category, determines the model type, defaults to "text-classification"
+            lora: lora configuration
+
+        Returns:
+            formatted lora configuration
+        """
+
+        if lora:
+            # Default lora settings when set to True
+            if isinstance(lora, bool):
+                lora = {"r": 16, "lora_alpha": 8, "target_modules": "all-linear", "lora_dropout": 0.05, "bias": "none"}
+
+            # Load dictionary configuration
+            if isinstance(lora, dict):
+                # Set task type if missing
+                if "task_type" not in lora:
+                    lora["task_type"] = self.loratask(task)
+
+                lora = LoraConfig(**lora)
+
+        return lora
+
+    def loratask(self, task):
+        """
+        Looks up the corresponding LoRA task for input task.
+
+        Args:
+            task: optional model task or category, determines the model type, defaults to "text-classification"
+
+        Returns:
+            lora task
+        """
+
+        # Task mapping
+        tasks = {
+            "language-generation": TaskType.CAUSAL_LM,
+            "language-modeling": TaskType.FEATURE_EXTRACTION,
+            "question-answering": TaskType.QUESTION_ANS,
+            "sequence-sequence": TaskType.SEQ_2_SEQ_LM,
+            "text-classification": TaskType.SEQ_CLS,
+            "token-detection": TaskType.FEATURE_EXTRACTION,
+        }
+
+        # Default task
+        task = task if task in tasks else "text-classification"
+
+        # Lookup and return task
+        return tasks[task]
+
+
+class TrainingArguments(HFTrainingArguments):
+    """
+    Extends standard TrainingArguments to make the output directory optional for transient models.
+    """
+
+    @property
+    def should_save(self):
+        """
+        Override should_save to disable model saving when output directory is None.
+
+        Returns:
+            If model should be saved
+        """
+
+        return super().should_save if self.output_dir else False
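
Note: the `__call__` method above is the whole training entry point: keyword arguments flow through to TrainingArguments, data is tokenized for the selected task and the trained `(model, tokenizer)` pair is returned. A minimal usage sketch follows (illustrative, not part of this diff); the base model id and two-row dataset are placeholders and the `pipeline` extra is assumed to be installed.

from txtai.pipeline import HFTrainer

# Toy labeled dataset using the default (text, label) columns
train = [
    {"text": "great product, works as expected", "label": 1},
    {"text": "terrible experience, would not buy again", "label": 0},
]

# Fine-tune a small base model with the default text-classification task;
# extra keyword arguments are passed through to TrainingArguments
trainer = HFTrainer()
model, tokenizer = trainer("google/bert_uncased_L-2_H-128_A-2", train, num_train_epochs=1)
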
txtai/pipeline/train/mlonnx.py ADDED
@@ -0,0 +1,63 @@
+"""
+Machine learning model to ONNX export module
+"""
+
+from ..base import Pipeline
+
+try:
+    from onnxmltools import convert_sklearn
+
+    from skl2onnx.common.data_types import StringTensorType
+    from skl2onnx.helpers.onnx_helper import save_onnx_model, select_model_inputs_outputs
+
+    ONNX_MLTOOLS = True
+except ImportError:
+    ONNX_MLTOOLS = False
+
+
+class MLOnnx(Pipeline):
+    """
+    Exports a machine learning model to ONNX using ONNXMLTools.
+    """
+
+    def __init__(self):
+        """
+        Creates a new MLOnnx pipeline.
+        """
+
+        if not ONNX_MLTOOLS:
+            raise ImportError('MLOnnx pipeline is not available - install "pipeline" extra to enable')
+
+    def __call__(self, model, task="default", output=None, opset=12):
+        """
+        Exports a machine learning model to ONNX using ONNXMLTools.
+
+        Args:
+            model: model to export
+            task: optional model task or category
+            output: optional output model path, defaults to return byte array if None
+            opset: onnx opset, defaults to 12
+
+        Returns:
+            path to model output or model as bytes depending on output parameter
+        """
+
+        # Convert scikit-learn model to ONNX
+        model = convert_sklearn(model, task, initial_types=[("input_ids", StringTensorType([None, None]))], target_opset=opset)
+
+        # Prune model graph down to only output probabilities
+        model = select_model_inputs_outputs(model, outputs="probabilities")
+
+        # pylint: disable=E1101
+        # Rename output to logits for consistency with other models
+        model.graph.output[0].name = "logits"
+
+        # Find probabilities output node and rename to logits
+        for node in model.graph.node:
+            for x, _ in enumerate(node.output):
+                if node.output[x] == "probabilities":
+                    node.output[x] = "logits"
+
+        # Save model to specified output path or return bytes
+        model = save_onnx_model(model, output)
+        return output if output else model
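
A brief usage sketch for the pipeline above (illustrative, not part of this diff): exporting a scikit-learn text classifier to ONNX. It assumes scikit-learn, onnxmltools and skl2onnx are installed; the model and file name are placeholders.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

from txtai.pipeline import MLOnnx

# Train a small scikit-learn text classifier
model = Pipeline([("tfidf", TfidfVectorizer()), ("lr", LogisticRegression())])
model.fit(["txtai is great", "this is unrelated"], [1, 0])

# Export to ONNX; returns serialized model bytes when no output path is given
onnx = MLOnnx()
path = onnx(model, "text-classify", "model.onnx")
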
txtai/scoring/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""
+Scoring imports
+"""
+
+from .base import Scoring
+from .bm25 import BM25
+from .factory import ScoringFactory
+from .pgtext import PGText
+from .sif import SIF
+from .sparse import Sparse
+from .terms import Terms
+from .tfidf import TFIDF
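
These exports are typically consumed through ScoringFactory (see txtai/scoring/factory.py in the file list). A brief sketch of that usage (illustrative, not part of this diff), with placeholder documents:

from txtai.scoring import ScoringFactory

# Build a BM25 index over a few documents; "terms": True enables a searchable term index
scoring = ScoringFactory.create({"method": "bm25", "terms": True})
scoring.index([(0, "first test document", None), (1, "another document about search", None)])

# Returns a list of (id, score) tuples
print(scoring.search("search document", limit=2))
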
txtai/scoring/base.py ADDED
@@ -0,0 +1,188 @@
+"""
+Scoring module
+"""
+
+
+class Scoring:
+    """
+    Base scoring.
+    """
+
+    def __init__(self, config=None):
+        """
+        Creates a new Scoring instance.
+
+        Args:
+            config: input configuration
+        """
+
+        # Scoring configuration
+        self.config = config if config is not None else {}
+
+        # Transform columns
+        columns = self.config.get("columns", {})
+        self.text = columns.get("text", "text")
+        self.object = columns.get("object", "object")
+
+        # Vector model, if available
+        self.model = None
+
+    def insert(self, documents, index=None, checkpoint=None):
+        """
+        Inserts documents into the scoring index.
+
+        Args:
+            documents: list of (id, dict|text|tokens, tags)
+            index: indexid offset
+            checkpoint: optional checkpoint directory, enables indexing restart
+        """
+
+        raise NotImplementedError
+
+    def delete(self, ids):
+        """
+        Deletes documents from scoring index.
+
+        Args:
+            ids: list of ids to delete
+        """
+
+        raise NotImplementedError
+
+    def index(self, documents=None):
+        """
+        Indexes a collection of documents using a scoring method.
+
+        Args:
+            documents: list of (id, dict|text|tokens, tags)
+        """
+
+        # Insert documents
+        if documents:
+            self.insert(documents)
+
+    def upsert(self, documents=None):
+        """
+        Convenience method for API clarity. Calls the index method.
+
+        Args:
+            documents: list of (id, dict|text|tokens, tags)
+        """
+
+        self.index(documents)
+
+    def weights(self, tokens):
+        """
+        Builds a weights vector for each token in input tokens.
+
+        Args:
+            tokens: input tokens
+
+        Returns:
+            list of weights for each token
+        """
+
+        raise NotImplementedError
+
+    def search(self, query, limit=3):
+        """
+        Search index for documents matching query.
+
+        Args:
+            query: input query
+            limit: maximum results
+
+        Returns:
+            list of (id, score) or (data, score) if content is enabled
+        """
+
+        raise NotImplementedError
+
+    def batchsearch(self, queries, limit=3, threads=True):
+        """
+        Search index for documents matching queries.
+
+        Args:
+            queries: queries to run
+            limit: maximum results
+            threads: run as threaded search if True and supported
+        """
+
+        raise NotImplementedError
+
+    def count(self):
+        """
+        Returns the total number of documents indexed.
+
+        Returns:
+            total number of documents indexed
+        """
+
+        raise NotImplementedError
+
+    def load(self, path):
+        """
+        Loads a saved Scoring object from path.
+
+        Args:
+            path: directory path to load scoring index
+        """
+
+        raise NotImplementedError
+
+    def save(self, path):
+        """
+        Saves a Scoring object to path.
+
+        Args:
+            path: directory path to save scoring index
+        """
+
+        raise NotImplementedError
+
+    def close(self):
+        """
+        Closes this Scoring object.
+        """
+
+        raise NotImplementedError
+
+    def findmodel(self):
+        """
+        Returns the associated vector model used by this scoring instance, if any.
+
+        Returns:
+            associated vector model
+        """
+
+        return self.model
+
+    def issparse(self):
+        """
+        Check if this scoring instance has an associated sparse keyword or sparse vector index.
+
+        Returns:
+            True if this index has an associated sparse index
+        """
+
+        raise NotImplementedError
+
+    def isweighted(self):
+        """
+        Check if this scoring instance is for term weighting (i.e. it has no associated sparse index).
+
+        Returns:
+            True if this index is for term weighting
+        """
+
+        return not self.issparse()
+
+    def isnormalized(self):
+        """
+        Check if this scoring instance returns normalized scores.
+
+        Returns:
+            True if normalize is enabled, False otherwise
+        """
+
+        raise NotImplementedError
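
The class above is an abstract interface; concrete implementations such as TFIDF/BM25 below fill in insert, search and the persistence methods. As a rough illustration of that contract (a toy sketch, not part of this diff or the package), a minimal subclass might look like this:

from txtai.scoring import Scoring

class CountScoring(Scoring):
    """
    Toy example: scores documents by the number of query terms they contain.
    """

    def __init__(self, config=None):
        super().__init__(config)
        self.documents = {}

    def insert(self, documents, index=None, checkpoint=None):
        for uid, data, _ in documents:
            # Store the text column for dict rows, otherwise store the raw value
            self.documents[uid] = data[self.text] if isinstance(data, dict) else data

    def search(self, query, limit=3):
        terms = query.lower().split()
        scores = [(uid, sum(term in text.lower() for term in terms)) for uid, text in self.documents.items()]
        return sorted(scores, key=lambda x: x[1], reverse=True)[:limit]

    def count(self):
        return len(self.documents)

    def issparse(self):
        return True

    def isnormalized(self):
        return False
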
txtai/scoring/bm25.py ADDED
@@ -0,0 +1,29 @@
+"""
+BM25 module
+"""
+
+import numpy as np
+
+from .tfidf import TFIDF
+
+
+class BM25(TFIDF):
+    """
+    Best matching (BM25) scoring.
+    """
+
+    def __init__(self, config=None):
+        super().__init__(config)
+
+        # BM25 configurable parameters
+        self.k1 = self.config.get("k1", 1.2)
+        self.b = self.config.get("b", 0.75)
+
+    def computeidf(self, freq):
+        # Calculate BM25 IDF score
+        return np.log(1 + (self.total - freq + 0.5) / (freq + 0.5))
+
+    def score(self, freq, idf, length):
+        # Calculate BM25 score
+        k = self.k1 * ((1 - self.b) + self.b * length / self.avgdl)
+        return idf * (freq * (self.k1 + 1)) / (freq + k)
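
For reference, score() above is the standard BM25 term weight, idf * (freq * (k1 + 1)) / (freq + k1 * (1 - b + b * length / avgdl)), where length is the document length and avgdl/total are maintained by the parent TFIDF index. A quick numeric check with the default parameters (a standalone sketch, not part of this diff):

# Term appearing twice (freq=2) in a document of average length (length == avgdl)
k1, b, idf = 1.2, 0.75, 1.0
freq, length, avgdl = 2, 100, 100

k = k1 * ((1 - b) + b * length / avgdl)       # 1.2
score = idf * (freq * (k1 + 1)) / (freq + k)  # 4.4 / 3.2 = 1.375
print(score)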