mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/pipeline/data/tabular.py
@@ -0,0 +1,155 @@
+"""
+Tabular module
+"""
+
+import os
+
+# Conditional import
+try:
+    import pandas as pd
+
+    PANDAS = True
+except ImportError:
+    PANDAS = False
+
+from ..base import Pipeline
+
+
+class Tabular(Pipeline):
+    """
+    Splits tabular data into rows and columns.
+    """
+
+    def __init__(self, idcolumn=None, textcolumns=None, content=False):
+        """
+        Creates a new Tabular pipeline.
+
+        Args:
+            idcolumn: column name to use for row id
+            textcolumns: list of columns to combine as a text field
+            content: if True, a dict per row is generated with all fields. If content is a list, a subset of fields
+                     is included in the generated rows.
+        """
+
+        if not PANDAS:
+            raise ImportError('Tabular pipeline is not available - install "pipeline" extra to enable')
+
+        self.idcolumn = idcolumn
+        self.textcolumns = textcolumns
+        self.content = content
+
+    def __call__(self, data):
+        """
+        Splits data into rows and columns.
+
+        Args:
+            data: input data
+
+        Returns:
+            list of (id, text, tag)
+        """
+
+        items = [data] if not isinstance(data, list) else data
+
+        # Combine all rows into single return element
+        results = []
+        dicts = []
+
+        for item in items:
+            # File path
+            if isinstance(item, str):
+                _, extension = os.path.splitext(item)
+                extension = extension.replace(".", "").lower()
+
+                if extension == "csv":
+                    df = pd.read_csv(item)
+
+                results.append(self.process(df))
+
+            # Dict
+            if isinstance(item, dict):
+                dicts.append(item)
+
+            # List of dicts
+            elif isinstance(item, list):
+                df = pd.DataFrame(item)
+                results.append(self.process(df))
+
+        if dicts:
+            df = pd.DataFrame(dicts)
+            results.extend(self.process(df))
+
+        return results[0] if not isinstance(data, list) else results
+
+    def process(self, df):
+        """
+        Extracts a list of (id, text, tag) tuples from a dataframe.
+
+        Args:
+            df: DataFrame to extract content from
+
+        Returns:
+            list of (id, text, tag)
+        """
+
+        rows = []
+
+        # Columns to use for text
+        columns = self.textcolumns
+        if not columns:
+            columns = list(df.columns)
+            if self.idcolumn:
+                columns.remove(self.idcolumn)
+
+        # Transform into (id, text, tag) tuples
+        for index, row in df.iterrows():
+            uid = row[self.idcolumn] if self.idcolumn else index
+            uid = uid if uid is not None else index
+            text = self.concat(row, columns)
+
+            rows.append((uid, text, None))
+
+            # Also add row for content
+            if isinstance(self.content, list):
+                row = {column: self.column(value) for column, value in row.to_dict().items() if column in self.content}
+                rows.append((uid, row, None))
+            elif self.content:
+                row = {column: self.column(value) for column, value in row.to_dict().items()}
+                rows.append((uid, row, None))
+
+        return rows
+
+    def concat(self, row, columns):
+        """
+        Builds a text field from row using columns.
+
+        Args:
+            row: input row
+            columns: list of columns to join together
+
+        Returns:
+            text
+        """
+
+        parts = []
+        for column in columns:
+            column = self.column(row[column])
+            if column:
+                parts.append(str(column))
+
+        return ". ".join(parts) if parts else None
+
+    def column(self, value):
+        """
+        Applies column standardization logic:
+          - Replace NaN values with None
+
+        Args:
+            value: input value
+
+        Returns:
+            formatted value
+        """
+
+        # Check for null - treat lists as not null
+        return None if not isinstance(value, list) and pd.isnull(value) else value
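A minimal usage sketch for the Tabular pipeline above (not part of the diff; the column names and data are invented for illustration, and pandas must be installed via the "pipeline" extra):

```python
from txtai.pipeline import Tabular

# Map the "id" column to row ids, build the text field from two columns
tabular = Tabular(idcolumn="id", textcolumns=["title", "body"])

rows = tabular([{"id": 1, "title": "Hello", "body": "World"}])
# -> [(1, "Hello. World", None)]
```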
txtai/pipeline/data/textractor.py
@@ -0,0 +1,139 @@
+"""
+Textractor module
+"""
+
+import contextlib
+import os
+import tempfile
+
+from urllib.parse import urlparse
+from urllib.request import urlopen, Request
+
+from .filetohtml import FileToHTML
+from .htmltomd import HTMLToMarkdown
+from .segmentation import Segmentation
+
+
+class Textractor(Segmentation):
+    """
+    Extracts text from files.
+    """
+
+    # pylint: disable=R0913
+    def __init__(
+        self,
+        sentences=False,
+        lines=False,
+        paragraphs=False,
+        minlength=None,
+        join=False,
+        sections=False,
+        cleantext=True,
+        chunker=None,
+        headers=None,
+        backend="available",
+        **kwargs
+    ):
+        super().__init__(sentences, lines, paragraphs, minlength, join, sections, cleantext, chunker, **kwargs)
+
+        # Get backend parameter - handle legacy tika flag
+        backend = "tika" if "tika" in kwargs and kwargs["tika"] else None if "tika" in kwargs else backend
+
+        # File to HTML pipeline
+        self.html = FileToHTML(backend) if backend else None
+
+        # HTML to Markdown pipeline
+        self.markdown = HTMLToMarkdown(self.paragraphs, self.sections)
+
+        # HTTP headers
+        self.headers = headers if headers else {}
+
+    def text(self, text):
+        # Check if text is a valid file path or url
+        path, exists = self.valid(text)
+
+        if not path:
+            # Not a valid file path, treat input as data
+            html = text
+
+        elif self.html:
+            # Use FileToHTML pipeline, if available
+            # Retrieve remote file, if necessary
+            path = path if exists else self.download(path)
+
+            # Parse content to HTML
+            html = self.html(path)
+
+            # FileToHTML pipeline returns None when input is already HTML
+            html = html if html else self.retrieve(path)
+
+            # Delete temporary file
+            if not exists:
+                os.remove(path)
+
+        else:
+            # Read data from url/path
+            html = self.retrieve(path)
+
+        # HTML to Markdown
+        return self.markdown(html)
+
+    def valid(self, path):
+        """
+        Checks if path is a valid local file or web url. Returns path if valid along with a flag
+        denoting if the path exists locally.
+
+        Args:
+            path: path to check
+
+        Returns:
+            (path, exists)
+        """
+
+        # Convert file urls to local paths
+        path = path.replace("file://", "")
+
+        # Check if this is a local file path or local file url
+        exists = os.path.exists(path)
+
+        # Consider local files and HTTP urls valid
+        return (path if exists or urlparse(path).scheme in ("http", "https") else None, exists)
+
+    def download(self, url):
+        """
+        Downloads content of url to a temporary file.
+
+        Args:
+            url: input url
+
+        Returns:
+            temporary file path
+        """
+
+        with tempfile.NamedTemporaryFile(mode="wb", delete=False) as output:
+            path = output.name
+
+            # Retrieve and write data to temporary file
+            output.write(self.retrieve(url))
+
+        return path
+
+    def retrieve(self, url):
+        """
+        Retrieves content from url.
+
+        Args:
+            url: input url
+
+        Returns:
+            data
+        """
+
+        # Local file
+        if os.path.exists(url):
+            with open(url, "rb") as f:
+                return f.read()
+
+        # Remote file
+        with contextlib.closing(urlopen(Request(url, headers=self.headers))) as connection:
+            return connection.read()
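An illustrative sketch of calling the Textractor pipeline above (the file name is hypothetical, and the file parsing backends from the "pipeline" extra must be installed):

```python
from txtai.pipeline import Textractor

# Extract text from a local document as a list of paragraphs
textractor = Textractor(paragraphs=True)
for paragraph in textractor("document.pdf"):
    print(paragraph)
```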
txtai/pipeline/data/tokenizer.py
@@ -0,0 +1,112 @@
+"""
+Tokenizer module
+"""
+
+import re
+import string
+
+import regex
+
+from ..base import Pipeline
+
+
+class Tokenizer(Pipeline):
+    """
+    Tokenizes text into tokens using one of the following methods.
+
+    1. Backwards compatible tokenization that only accepts alphanumeric tokens from the Latin alphabet.
+
+    2. Split using word boundary rules from the Unicode Text Segmentation algorithm (see Unicode Standard Annex #29).
+       This is similar to the standard tokenizer in Apache Lucene and works well for most languages.
+    """
+
+    # fmt: off
+    # English Stop Word List (Standard stop words used by Apache Lucene)
+    STOP_WORDS = {"a", "an", "and", "are", "as", "at", "be", "but", "by", "for", "if", "in", "into", "is", "it",
+                  "no", "not", "of", "on", "or", "such", "that", "the", "their", "then", "there", "these",
+                  "they", "this", "to", "was", "will", "with"}
+    # fmt: on
+
+    @staticmethod
+    def tokenize(text, lowercase=True, emoji=True, alphanum=True, stopwords=True):
+        """
+        Tokenizes text into a list of tokens. The default backwards compatible parameters filter out English stop words and only
+        accept alphanumeric tokens.
+
+        Args:
+            text: input text
+            lowercase: lower cases all tokens if True, defaults to True
+            emoji: tokenize emoji in text if True, defaults to True
+            alphanum: requires 2+ character alphanumeric tokens if True, defaults to True
+            stopwords: removes provided stop words if a list, removes default English stop words if True, defaults to True
+
+        Returns:
+            list of tokens
+        """
+
+        # Create a tokenizer with backwards compatible settings
+        return Tokenizer(lowercase, emoji, alphanum, stopwords)(text)
+
+    def __init__(self, lowercase=True, emoji=True, alphanum=False, stopwords=False):
+        """
+        Creates a new tokenizer. The default parameters segment text per Unicode Standard Annex #29.
+
+        Args:
+            lowercase: lower cases all tokens if True, defaults to True
+            emoji: tokenize emoji in text if True, defaults to True
+            alphanum: requires 2+ character alphanumeric tokens if True, defaults to False
+            stopwords: removes provided stop words if a list, removes default English stop words if True, defaults to False
+        """
+
+        # Lowercase
+        self.lowercase = lowercase
+
+        # Text segmentation
+        self.alphanum, self.segment = None, None
+        if alphanum:
+            # Alphanumeric regex that accepts tokens that meet following rules:
+            # - Strings to be at least 2 characters long AND
+            # - At least 1 non-trailing alpha character in string
+            # Note: The standard Python re module is much faster than regex for this expression
+            self.alphanum = re.compile(r"^\d*[a-z][\-.0-9:_a-z]{1,}$")
+        else:
+            # Text segmentation per Unicode Standard Annex #29
+            pattern = r"\w\p{Extended_Pictographic}\p{WB:RegionalIndicator}" if emoji else r"\w"
+            self.segment = regex.compile(rf"[{pattern}](?:\B\S)*", flags=regex.WORD)
+
+        # Stop words
+        self.stopwords = stopwords if isinstance(stopwords, list) else Tokenizer.STOP_WORDS if stopwords else False
+
+    def __call__(self, text):
+        """
+        Tokenizes text into a list of tokens.
+
+        Args:
+            text: input text
+
+        Returns:
+            list of tokens
+        """
+
+        # Check for None and skip processing
+        if text is None:
+            return None
+
+        # Lowercase
+        text = text.lower() if self.lowercase else text
+
+        if self.alphanum:
+            # Text segmentation using standard split
+            tokens = [token.strip(string.punctuation) for token in text.split()]
+
+            # Filter on alphanumeric strings.
+            tokens = [token for token in tokens if re.match(self.alphanum, token)]
+        else:
+            # Text segmentation per Unicode Standard Annex #29
+            tokens = regex.findall(self.segment, text)
+
+        # Stop words
+        if self.stopwords:
+            tokens = [token for token in tokens if token not in self.stopwords]
+
+        return tokens
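A short sketch of both tokenization modes defined above (the outputs in comments follow from the code shown and are illustrative):

```python
from txtai.pipeline import Tokenizer

# Default constructor: Unicode Standard Annex #29 word segmentation
tokenizer = Tokenizer()
print(tokenizer("The quick brown fox"))
# -> ['the', 'quick', 'brown', 'fox']

# Backwards compatible static method: filters stop words and
# requires 2+ character alphanumeric tokens
print(Tokenizer.tokenize("The quick brown fox!"))
# -> ['quick', 'brown', 'fox']
```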
txtai/pipeline/factory.py
@@ -0,0 +1,77 @@
+"""
+Pipeline factory module
+"""
+
+import inspect
+import sys
+import types
+
+from ..util import Resolver
+
+from .base import Pipeline
+
+
+class PipelineFactory:
+    """
+    Pipeline factory. Creates new Pipeline instances.
+    """
+
+    @staticmethod
+    def get(pipeline):
+        """
+        Gets a new instance of pipeline class.
+
+        Args:
+            pipeline: Pipeline instance class
+
+        Returns:
+            Pipeline class
+        """
+
+        # Local pipeline if no package
+        if "." not in pipeline:
+            return PipelineFactory.list()[pipeline]
+
+        # Attempt to load custom pipeline
+        return Resolver()(pipeline)
+
+    @staticmethod
+    def create(config, pipeline):
+        """
+        Creates a new Pipeline instance.
+
+        Args:
+            config: Pipeline configuration
+            pipeline: Pipeline instance class
+
+        Returns:
+            Pipeline
+        """
+
+        # Resolve pipeline
+        pipeline = PipelineFactory.get(pipeline)
+
+        # Return functions directly, otherwise create pipeline instance
+        return pipeline if isinstance(pipeline, types.FunctionType) else pipeline(**config)
+
+    @staticmethod
+    def list():
+        """
+        Lists callable pipelines.
+
+        Returns:
+            {short name: pipeline class}
+        """
+
+        pipelines = {}
+
+        # Get handle to pipeline module
+        pipeline = sys.modules[".".join(__name__.split(".")[:-1])]
+
+        # Get list of callable pipelines
+        for x in inspect.getmembers(pipeline, inspect.isclass):
+            if issubclass(x[1], Pipeline) and [y for y, _ in inspect.getmembers(x[1], inspect.isfunction) if y == "__call__"]:
+                # short name: pipeline class
+                pipelines[x[0].lower()] = x[1]
+
+        return pipelines
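A brief sketch of the factory contract above (Nop is a pipeline shipped in this package per the file list; the custom class path in the final comment is hypothetical):

```python
from txtai.pipeline import PipelineFactory

# Registered pipelines: {short name: pipeline class}
print(sorted(PipelineFactory.list()))

# Create by short name with a configuration dict
nop = PipelineFactory.create({}, "nop")
print(nop("input passes through"))

# Names containing "." resolve to custom classes, e.g.
# PipelineFactory.get("mymodule.MyPipeline")
```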
txtai/pipeline/hfmodel.py
@@ -0,0 +1,111 @@
+"""
+Hugging Face Transformers model wrapper module
+"""
+
+from ..models import Models
+from .tensors import Tensors
+
+
+class HFModel(Tensors):
+    """
+    Pipeline backed by a Hugging Face Transformers model.
+    """
+
+    def __init__(self, path=None, quantize=False, gpu=False, batch=64):
+        """
+        Creates a new HFModel.
+
+        Args:
+            path: optional path to model, accepts Hugging Face model hub id or local path,
+                  uses default model for task if not provided
+            quantize: if model should be quantized, defaults to False
+            gpu: True/False if GPU should be enabled, also supports a GPU device id
+            batch: batch size used to incrementally process content
+        """
+
+        # Default model path
+        self.path = path
+
+        # Quantization flag
+        self.quantization = quantize
+
+        # Get tensor device reference
+        self.deviceid = Models.deviceid(gpu)
+        self.device = Models.device(self.deviceid)
+
+        # Process batch size
+        self.batchsize = batch
+
+    def prepare(self, model):
+        """
+        Prepares a model for processing. Applies dynamic quantization if necessary.
+
+        Args:
+            model: input model
+
+        Returns:
+            model
+        """
+
+        if self.deviceid == -1 and self.quantization:
+            model = self.quantize(model)
+
+        return model
+
+    def tokenize(self, tokenizer, texts):
+        """
+        Tokenizes text using tokenizer. This method handles overflowing tokens and automatically splits
+        them into separate elements. Indices of each element are returned to allow reconstructing the
+        transformed elements after running through the model.
+
+        Args:
+            tokenizer: Tokenizer
+            texts: list of text
+
+        Returns:
+            (tokenization result, indices)
+        """
+
+        # Pre-process and split on newlines
+        batch, positions = [], []
+        for x, text in enumerate(texts):
+            elements = [t + " " for t in text.split("\n") if t]
+            batch.extend(elements)
+            positions.extend([x] * len(elements))
+
+        # Run tokenizer
+        tokens = tokenizer(batch, padding=True)
+
+        inputids, attention, indices = [], [], []
+        for x, ids in enumerate(tokens["input_ids"]):
+            if len(ids) > tokenizer.model_max_length:
+                # Remove padding characters, if any
+                ids = [i for i in ids if i != tokenizer.pad_token_id]
+
+                # Split into model_max_length chunks
+                for chunk in self.batch(ids, tokenizer.model_max_length - 1):
+                    # Append EOS token if necessary
+                    if chunk[-1] != tokenizer.eos_token_id:
+                        chunk.append(tokenizer.eos_token_id)
+
+                    # Set attention mask
+                    mask = [1] * len(chunk)
+
+                    # Append padding if necessary
+                    if len(chunk) < tokenizer.model_max_length:
+                        pad = tokenizer.model_max_length - len(chunk)
+                        chunk.extend([tokenizer.pad_token_id] * pad)
+                        mask.extend([0] * pad)
+
+                    inputids.append(chunk)
+                    attention.append(mask)
+                    indices.append(positions[x])
+            else:
+                inputids.append(ids)
+                attention.append(tokens["attention_mask"][x])
+                indices.append(positions[x])
+
+        tokens = {"input_ids": inputids, "attention_mask": attention}
+
+        # pylint: disable=E1102
+        return ({name: self.tensor(tensor).to(self.device) for name, tensor in tokens.items()}, indices)
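A sketch of the tokenize() chunking contract above (the tokenizer name is illustrative; assumes torch and transformers are installed, since tensor conversion runs through the Tensors base class):

```python
from transformers import AutoTokenizer

from txtai.pipeline import HFModel

model = HFModel(gpu=False)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# First text splits on the newline into two elements
tokens, indices = model.tokenize(tokenizer, ["First line\nSecond line", "Another input"])

# indices maps each tokenized element back to its input position
print(indices)
# -> [0, 0, 1]
```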
txtai/pipeline/hfpipeline.py
@@ -0,0 +1,96 @@
+"""
+Hugging Face Transformers pipeline wrapper module
+"""
+
+import inspect
+
+from transformers import pipeline
+
+from ..models import Models
+from ..util import Resolver
+
+from .tensors import Tensors
+
+
+class HFPipeline(Tensors):
+    """
+    Light wrapper around Hugging Face Transformers pipeline component for selected tasks. Adds support for model
+    quantization and minor interface changes.
+    """
+
+    def __init__(self, task, path=None, quantize=False, gpu=False, model=None, **kwargs):
+        """
+        Loads a new pipeline model.
+
+        Args:
+            task: pipeline task or category
+            path: optional path to model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple.
+                  uses default model for task if not provided
+            quantize: if model should be quantized, defaults to False
+            gpu: True/False if GPU should be enabled, also supports a GPU device id
+            model: optional existing pipeline model to wrap
+            kwargs: additional keyword arguments to pass to pipeline model
+        """
+
+        if model:
+            # Check if input model is a Pipeline or a HF pipeline
+            self.pipeline = model.pipeline if isinstance(model, HFPipeline) else model
+        else:
+            # Get device
+            deviceid = Models.deviceid(gpu) if "device_map" not in kwargs else None
+            device = Models.device(deviceid) if deviceid is not None else None
+
+            # Split into model args, pipeline args
+            modelargs, kwargs = self.parseargs(**kwargs)
+
+            # Transformer pipeline task
+            if isinstance(path, (list, tuple)):
+                # Derive configuration, if possible
+                config = path[1] if path[1] and isinstance(path[1], str) else None
+
+                # Load model
+                model = Models.load(path[0], config, task)
+
+                self.pipeline = pipeline(task, model=model, tokenizer=path[1], device=device, model_kwargs=modelargs, **kwargs)
+            else:
+                self.pipeline = pipeline(task, model=path, device=device, model_kwargs=modelargs, **kwargs)
+
+            # Model quantization. Compresses model to int8 precision, improves runtime performance. Only supported on CPU.
+            if deviceid == -1 and quantize:
+                # pylint: disable=E1101
+                self.pipeline.model = self.quantize(self.pipeline.model)
+
+        # Detect unbounded tokenizer typically found in older models
+        Models.checklength(self.pipeline.model, self.pipeline.tokenizer)
+
+    def parseargs(self, **kwargs):
+        """
+        Inspects the pipeline method and splits kwargs into model args and pipeline args.
+
+        Args:
+            kwargs: all keyword arguments
+
+        Returns:
+            (model args, pipeline args)
+        """
+
+        # Get pipeline method arguments
+        args = inspect.getfullargspec(pipeline).args
+
+        # Resolve torch dtype, if necessary
+        dtype = kwargs.get("torch_dtype")
+        if dtype and isinstance(dtype, str) and dtype != "auto":
+            kwargs["torch_dtype"] = Resolver()(dtype)
+
+        # Split into modelargs and kwargs
+        return ({arg: value for arg, value in kwargs.items() if arg not in args}, {arg: value for arg, value in kwargs.items() if arg in args})
+
+    def maxlength(self):
+        """
+        Gets the max length to use for generate calls.
+
+        Returns:
+            max length
+        """
+
+        return Models.maxlength(self.pipeline.model, self.pipeline.tokenizer)
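An illustrative sketch of the wrapper above (the model id is an example; concrete txtai pipelines such as Labels and Summary extend HFPipeline rather than using it directly):

```python
from txtai.pipeline import HFPipeline

wrapper = HFPipeline("text-classification", path="distilbert-base-uncased-finetuned-sst-2-english", gpu=False)

# Underlying Transformers pipeline is available as .pipeline
print(wrapper.pipeline("txtai makes semantic search easy"))

# Max length derived from the wrapped model and tokenizer
print(wrapper.maxlength())
```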