mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/train/hftrainer.py ADDED
@@ -0,0 +1,398 @@
"""
Hugging Face Transformers trainer wrapper module
"""

import os
import sys

import torch

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoModelForPreTraining,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, Trainer, set_seed
from transformers import TrainingArguments as HFTrainingArguments

# Conditional import
try:
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

    # pylint: disable=C0412
    from transformers import BitsAndBytesConfig

    PEFT = True
except ImportError:
    PEFT = False

from ...data import Labels, Questions, Sequences, Texts
from ...models import Models, TokenDetection
from ..tensors import Tensors


class HFTrainer(Tensors):
    """
    Trains a new Hugging Face Transformer model using the Trainer framework.
    """

    # pylint: disable=R0913
    def __call__(
        self,
        base,
        train,
        validation=None,
        columns=None,
        maxlength=None,
        stride=128,
        task="text-classification",
        prefix=None,
        metrics=None,
        tokenizers=None,
        checkpoint=None,
        quantize=None,
        lora=None,
        **args
    ):
        """
        Builds a new model using arguments.

        Args:
            base: path to base model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple
            train: training data
            validation: validation data
            columns: tuple of columns to use for text/label, defaults to (text, None, label)
            maxlength: maximum sequence length, defaults to tokenizer.model_max_length
            stride: chunk size for splitting data for QA tasks
            task: optional model task or category, determines the model type, defaults to "text-classification"
            prefix: optional source prefix
            metrics: optional function that computes and returns a dict of evaluation metrics
            tokenizers: optional number of concurrent tokenizers, defaults to None
            checkpoint: optional resume from checkpoint flag or path to checkpoint directory, defaults to None
            quantize: quantization configuration to pass to base model
            lora: lora configuration to pass to PEFT model
            args: training arguments

        Returns:
            (model, tokenizer)
        """

        # Quantization / LoRA support
        if (quantize or lora) and not PEFT:
            raise ImportError('PEFT is not available - install "pipeline" extra to enable')

        # Parse TrainingArguments
        args = self.parse(args)

        # Set seed for model reproducibility
        set_seed(args.seed)

        # Load model configuration, tokenizer and max sequence length
        config, tokenizer, maxlength = self.load(base, maxlength)

        # Default tokenizer pad token if it's not set
        tokenizer.pad_token = tokenizer.pad_token if tokenizer.pad_token is not None else tokenizer.eos_token

        # Prepare parameters
        process, collator, labels = self.prepare(task, train, tokenizer, columns, maxlength, stride, prefix, args)

        # Tokenize training and validation data
        train, validation = process(train, validation, os.cpu_count() if tokenizers and isinstance(tokenizers, bool) else tokenizers)

        # Create model to train
        model = self.model(task, base, config, labels, tokenizer, quantize)

        # Default config pad token if it's not set
        model.config.pad_token_id = model.config.pad_token_id if model.config.pad_token_id is not None else model.config.eos_token_id

        # Load as PEFT model, if necessary
        model = self.peft(task, lora, model)

        # Add model to collator
        if collator:
            collator.model = model

        # Build trainer
        trainer = Trainer(
            model=model,
            tokenizer=tokenizer,
            data_collator=collator,
            args=args,
            train_dataset=train,
            eval_dataset=validation if validation else None,
            compute_metrics=metrics,
        )

        # Run training
        trainer.train(resume_from_checkpoint=checkpoint)

        # Run evaluation
        if validation:
            trainer.evaluate()

        # Save model outputs
        if args.should_save:
            trainer.save_model()
            trainer.save_state()

        # Put model in eval mode to disable weight updates and return (model, tokenizer)
        return (model.eval(), tokenizer)

    def parse(self, updates):
        """
        Parses and merges custom arguments with defaults.

        Args:
            updates: custom arguments

        Returns:
            TrainingArguments
        """

        # Default training arguments
        args = {"output_dir": "", "save_strategy": "no", "report_to": "none", "log_level": "warning", "use_cpu": not Models.hasaccelerator()}

        # Apply custom arguments
        args.update(updates)

        return TrainingArguments(**args)

    def load(self, base, maxlength):
        """
        Loads the base config and tokenizer.

        Args:
            base: base model - supports a file path or (model, tokenizer) tuple
            maxlength: maximum sequence length

        Returns:
            (config, tokenizer, maxlength)
        """

        if isinstance(base, (list, tuple)):
            # Unpack existing config and tokenizer
            model, tokenizer = base
            config = model.config
        else:
            # Load config
            config = AutoConfig.from_pretrained(base)

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(base)

        # Detect unbounded tokenizer
        Models.checklength(config, tokenizer)

        # Derive max sequence length
        maxlength = min(maxlength if maxlength else sys.maxsize, tokenizer.model_max_length)

        return (config, tokenizer, maxlength)

    def prepare(self, task, train, tokenizer, columns, maxlength, stride, prefix, args):
        """
        Prepares data for model training.

        Args:
            task: optional model task or category, determines the model type, defaults to "text-classification"
            train: training data
            tokenizer: model tokenizer
            columns: tuple of columns to use for text/label, defaults to (text, None, label)
            maxlength: maximum sequence length, defaults to tokenizer.model_max_length
            stride: chunk size for splitting data for QA tasks
            prefix: optional source prefix
            args: training arguments
        """

        process, collator, labels = None, None, None

        if task == "language-generation":
            process = Texts(tokenizer, columns, maxlength)
            collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8 if args.fp16 else None)
        elif task in ("language-modeling", "token-detection"):
            process = Texts(tokenizer, columns, maxlength)
            collator = DataCollatorForLanguageModeling(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
        elif task == "question-answering":
            process = Questions(tokenizer, columns, maxlength, stride)
        elif task == "sequence-sequence":
            process = Sequences(tokenizer, columns, maxlength, prefix)
            collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
        else:
            process = Labels(tokenizer, columns, maxlength)
            labels = process.labels(train)

        return process, collator, labels

    def model(self, task, base, config, labels, tokenizer, quantize):
        """
        Loads the base model to train.

        Args:
            task: optional model task or category, determines the model type, defaults to "text-classification"
            base: base model - supports a file path or (model, tokenizer) tuple
            config: model configuration
            labels: number of labels
            tokenizer: model tokenizer
            quantize: quantization config

        Returns:
            model
        """

        if labels is not None:
            # Add number of labels to config
            config.update({"num_labels": labels})

        # Format quantization configuration
        quantization = self.quantization(quantize)

        # Clear quantization configuration if GPU is not available
        quantization = quantization if torch.cuda.is_available() else None

        # pylint: disable=E1120
        # Unpack existing model or create new model from config
        if isinstance(base, (list, tuple)) and not isinstance(base[0], str):
            return base[0]
        if task == "language-generation":
            return AutoModelForCausalLM.from_pretrained(base, config=config, quantization_config=quantization)
        if task == "language-modeling":
            return AutoModelForMaskedLM.from_pretrained(base, config=config, quantization_config=quantization)
        if task == "question-answering":
            return AutoModelForQuestionAnswering.from_pretrained(base, config=config, quantization_config=quantization)
        if task == "sequence-sequence":
            return AutoModelForSeq2SeqLM.from_pretrained(base, config=config, quantization_config=quantization)
        if task == "token-detection":
            return TokenDetection(
                AutoModelForMaskedLM.from_pretrained(base, config=config, quantization_config=quantization),
                AutoModelForPreTraining.from_pretrained(base, config=config, quantization_config=quantization),
                tokenizer,
            )

        # Default task
        return AutoModelForSequenceClassification.from_pretrained(base, config=config, quantization_config=quantization)

    def quantization(self, quantize):
        """
        Formats and returns quantization configuration.

        Args:
            quantize: input quantization configuration

        Returns:
            formatted quantization configuration
        """

        if quantize:
            # Default quantization settings when set to True
            if isinstance(quantize, bool):
                quantize = {
                    "load_in_4bit": True,
                    "bnb_4bit_use_double_quant": True,
                    "bnb_4bit_quant_type": "nf4",
                    "bnb_4bit_compute_dtype": "bfloat16",
                }

            # Load dictionary configuration
            if isinstance(quantize, dict):
                quantize = BitsAndBytesConfig(**quantize)

        return quantize if quantize else None

    def peft(self, task, lora, model):
        """
        Wraps the input model as a PEFT model if lora configuration is set.

        Args:
            task: optional model task or category, determines the model type, defaults to "text-classification"
            lora: lora configuration
            model: transformers model

        Returns:
            wrapped model if lora configuration set, otherwise input model is returned
        """

        if lora:
            # Format LoRA configuration
            config = self.lora(task, lora)

            # Wrap as PeftModel
            model = prepare_model_for_kbit_training(model)
            model = get_peft_model(model, config)
            model.print_trainable_parameters()

        return model

    def lora(self, task, lora):
        """
        Formats and returns LoRA configuration.

        Args:
            task: optional model task or category, determines the model type, defaults to "text-classification"
            lora: lora configuration

        Returns:
            formatted lora configuration
        """

        if lora:
            # Default lora settings when set to True
            if isinstance(lora, bool):
                lora = {"r": 16, "lora_alpha": 8, "target_modules": "all-linear", "lora_dropout": 0.05, "bias": "none"}

            # Load dictionary configuration
            if isinstance(lora, dict):
                # Set task type if missing
                if "task_type" not in lora:
                    lora["task_type"] = self.loratask(task)

                lora = LoraConfig(**lora)

        return lora

    def loratask(self, task):
        """
        Looks up the corresponding LoRA task for input task.

        Args:
            task: optional model task or category, determines the model type, defaults to "text-classification"

        Returns:
            lora task
        """

        # Task mapping
        tasks = {
            "language-generation": TaskType.CAUSAL_LM,
            "language-modeling": TaskType.FEATURE_EXTRACTION,
            "question-answering": TaskType.QUESTION_ANS,
            "sequence-sequence": TaskType.SEQ_2_SEQ_LM,
            "text-classification": TaskType.SEQ_CLS,
            "token-detection": TaskType.FEATURE_EXTRACTION,
        }

        # Default task
        task = task if task in tasks else "text-classification"

        # Lookup and return task
        return tasks[task]


class TrainingArguments(HFTrainingArguments):
    """
    Extends standard TrainingArguments to make the output directory optional for transient models.
    """

    @property
    def should_save(self):
        """
        Override should_save to disable model saving when output directory is None.

        Returns:
            If model should be saved
        """

        return super().should_save if self.output_dir else False
txtai/pipeline/train/mlonnx.py ADDED
@@ -0,0 +1,63 @@
"""
Machine learning model to ONNX export module
"""

from ..base import Pipeline

try:
    from onnxmltools import convert_sklearn

    from skl2onnx.common.data_types import StringTensorType
    from skl2onnx.helpers.onnx_helper import save_onnx_model, select_model_inputs_outputs

    ONNX_MLTOOLS = True
except ImportError:
    ONNX_MLTOOLS = False


class MLOnnx(Pipeline):
    """
    Exports a machine learning model to ONNX using ONNXMLTools.
    """

    def __init__(self):
        """
        Creates a new MLOnnx pipeline.
        """

        if not ONNX_MLTOOLS:
            raise ImportError('MLOnnx pipeline is not available - install "pipeline" extra to enable')

    def __call__(self, model, task="default", output=None, opset=12):
        """
        Exports a machine learning model to ONNX using ONNXMLTools.

        Args:
            model: model to export
            task: optional model task or category
            output: optional output model path, defaults to return byte array if None
            opset: onnx opset, defaults to 12

        Returns:
            path to model output or model as bytes depending on output parameter
        """

        # Convert scikit-learn model to ONNX
        model = convert_sklearn(model, task, initial_types=[("input_ids", StringTensorType([None, None]))], target_opset=opset)

        # Prune model graph down to only output probabilities
        model = select_model_inputs_outputs(model, outputs="probabilities")

        # pylint: disable=E1101
        # Rename output to logits for consistency with other models
        model.graph.output[0].name = "logits"

        # Find probabilities output node and rename to logits
        for node in model.graph.node:
            for x, _ in enumerate(node.output):
                if node.output[x] == "probabilities":
                    node.output[x] = "logits"

        # Save model to specified output path or return bytes
        model = save_onnx_model(model, output)
        return output if output else model
txtai/scoring/__init__.py ADDED
@@ -0,0 +1,12 @@
"""
Scoring imports
"""

from .base import Scoring
from .bm25 import BM25
from .factory import ScoringFactory
from .pgtext import PGText
from .sif import SIF
from .sparse import Sparse
from .terms import Terms
from .tfidf import TFIDF
txtai/scoring/base.py ADDED
@@ -0,0 +1,188 @@
"""
Scoring module
"""


class Scoring:
    """
    Base scoring.
    """

    def __init__(self, config=None):
        """
        Creates a new Scoring instance.

        Args:
            config: input configuration
        """

        # Scoring configuration
        self.config = config if config is not None else {}

        # Transform columns
        columns = self.config.get("columns", {})
        self.text = columns.get("text", "text")
        self.object = columns.get("object", "object")

        # Vector model, if available
        self.model = None

    def insert(self, documents, index=None, checkpoint=None):
        """
        Inserts documents into the scoring index.

        Args:
            documents: list of (id, dict|text|tokens, tags)
            index: indexid offset
            checkpoint: optional checkpoint directory, enables indexing restart
        """

        raise NotImplementedError

    def delete(self, ids):
        """
        Deletes documents from scoring index.

        Args:
            ids: list of ids to delete
        """

        raise NotImplementedError

    def index(self, documents=None):
        """
        Indexes a collection of documents using a scoring method.

        Args:
            documents: list of (id, dict|text|tokens, tags)
        """

        # Insert documents
        if documents:
            self.insert(documents)

    def upsert(self, documents=None):
        """
        Convenience method for API clarity. Calls index method.

        Args:
            documents: list of (id, dict|text|tokens, tags)
        """

        self.index(documents)

    def weights(self, tokens):
        """
        Builds a weights vector for each token in input tokens.

        Args:
            tokens: input tokens

        Returns:
            list of weights for each token
        """

        raise NotImplementedError

    def search(self, query, limit=3):
        """
        Search index for documents matching query.

        Args:
            query: input query
            limit: maximum results

        Returns:
            list of (id, score) or (data, score) if content is enabled
        """

        raise NotImplementedError

    def batchsearch(self, queries, limit=3, threads=True):
        """
        Search index for documents matching queries.

        Args:
            queries: queries to run
            limit: maximum results
            threads: run as threaded search if True and supported
        """

        raise NotImplementedError

    def count(self):
        """
        Returns the total number of documents indexed.

        Returns:
            total number of documents indexed
        """

        raise NotImplementedError

    def load(self, path):
        """
        Loads a saved Scoring object from path.

        Args:
            path: directory path to load scoring index
        """

        raise NotImplementedError

    def save(self, path):
        """
        Saves a Scoring object to path.

        Args:
            path: directory path to save scoring index
        """

        raise NotImplementedError

    def close(self):
        """
        Closes this Scoring object.
        """

        raise NotImplementedError

    def findmodel(self):
        """
        Returns the associated vector model used by this scoring instance, if any.

        Returns:
            associated vector model
        """

        return self.model

    def issparse(self):
        """
        Check if this scoring instance has an associated sparse keyword or sparse vector index.

        Returns:
            True if this index has an associated sparse index
        """

        raise NotImplementedError

    def isweighted(self):
        """
        Check if this scoring instance is for term weighting (i.e. it has no associated sparse index).

        Returns:
            True if this index is for term weighting
        """

        return not self.issparse()

    def isnormalized(self):
        """
        Check if this scoring instance returns normalized scores.

        Returns:
            True if normalize is enabled, False otherwise
        """

        raise NotImplementedError
txtai/scoring/bm25.py ADDED
@@ -0,0 +1,29 @@
"""
BM25 module
"""

import numpy as np

from .tfidf import TFIDF


class BM25(TFIDF):
    """
    Best matching (BM25) scoring.
    """

    def __init__(self, config=None):
        super().__init__(config)

        # BM25 configurable parameters
        self.k1 = self.config.get("k1", 1.2)
        self.b = self.config.get("b", 0.75)

    def computeidf(self, freq):
        # Calculate BM25 IDF score
        return np.log(1 + (self.total - freq + 0.5) / (freq + 0.5))

    def score(self, freq, idf, length):
        # Calculate BM25 score
        k = self.k1 * ((1 - self.b) + self.b * length / self.avgdl)
        return idf * (freq * (self.k1 + 1)) / (freq + k)