mseep_txtai-9.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mseep_txtai-9.1.1.dist-info/METADATA +262 -0
- mseep_txtai-9.1.1.dist-info/RECORD +251 -0
- mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
- mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
- mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
- txtai/__init__.py +16 -0
- txtai/agent/__init__.py +12 -0
- txtai/agent/base.py +54 -0
- txtai/agent/factory.py +39 -0
- txtai/agent/model.py +107 -0
- txtai/agent/placeholder.py +16 -0
- txtai/agent/tool/__init__.py +7 -0
- txtai/agent/tool/embeddings.py +69 -0
- txtai/agent/tool/factory.py +130 -0
- txtai/agent/tool/function.py +49 -0
- txtai/ann/__init__.py +7 -0
- txtai/ann/base.py +153 -0
- txtai/ann/dense/__init__.py +11 -0
- txtai/ann/dense/annoy.py +72 -0
- txtai/ann/dense/factory.py +76 -0
- txtai/ann/dense/faiss.py +233 -0
- txtai/ann/dense/hnsw.py +104 -0
- txtai/ann/dense/numpy.py +164 -0
- txtai/ann/dense/pgvector.py +323 -0
- txtai/ann/dense/sqlite.py +303 -0
- txtai/ann/dense/torch.py +38 -0
- txtai/ann/sparse/__init__.py +7 -0
- txtai/ann/sparse/factory.py +61 -0
- txtai/ann/sparse/ivfsparse.py +377 -0
- txtai/ann/sparse/pgsparse.py +56 -0
- txtai/api/__init__.py +18 -0
- txtai/api/application.py +134 -0
- txtai/api/authorization.py +53 -0
- txtai/api/base.py +159 -0
- txtai/api/cluster.py +295 -0
- txtai/api/extension.py +19 -0
- txtai/api/factory.py +40 -0
- txtai/api/responses/__init__.py +7 -0
- txtai/api/responses/factory.py +30 -0
- txtai/api/responses/json.py +56 -0
- txtai/api/responses/messagepack.py +51 -0
- txtai/api/route.py +41 -0
- txtai/api/routers/__init__.py +25 -0
- txtai/api/routers/agent.py +38 -0
- txtai/api/routers/caption.py +42 -0
- txtai/api/routers/embeddings.py +280 -0
- txtai/api/routers/entity.py +42 -0
- txtai/api/routers/extractor.py +28 -0
- txtai/api/routers/labels.py +47 -0
- txtai/api/routers/llm.py +61 -0
- txtai/api/routers/objects.py +42 -0
- txtai/api/routers/openai.py +191 -0
- txtai/api/routers/rag.py +61 -0
- txtai/api/routers/reranker.py +46 -0
- txtai/api/routers/segmentation.py +42 -0
- txtai/api/routers/similarity.py +48 -0
- txtai/api/routers/summary.py +46 -0
- txtai/api/routers/tabular.py +42 -0
- txtai/api/routers/textractor.py +42 -0
- txtai/api/routers/texttospeech.py +33 -0
- txtai/api/routers/transcription.py +42 -0
- txtai/api/routers/translation.py +46 -0
- txtai/api/routers/upload.py +36 -0
- txtai/api/routers/workflow.py +28 -0
- txtai/app/__init__.py +5 -0
- txtai/app/base.py +821 -0
- txtai/archive/__init__.py +9 -0
- txtai/archive/base.py +104 -0
- txtai/archive/compress.py +51 -0
- txtai/archive/factory.py +25 -0
- txtai/archive/tar.py +49 -0
- txtai/archive/zip.py +35 -0
- txtai/cloud/__init__.py +8 -0
- txtai/cloud/base.py +106 -0
- txtai/cloud/factory.py +70 -0
- txtai/cloud/hub.py +101 -0
- txtai/cloud/storage.py +125 -0
- txtai/console/__init__.py +5 -0
- txtai/console/__main__.py +22 -0
- txtai/console/base.py +264 -0
- txtai/data/__init__.py +10 -0
- txtai/data/base.py +138 -0
- txtai/data/labels.py +42 -0
- txtai/data/questions.py +135 -0
- txtai/data/sequences.py +48 -0
- txtai/data/texts.py +68 -0
- txtai/data/tokens.py +28 -0
- txtai/database/__init__.py +14 -0
- txtai/database/base.py +342 -0
- txtai/database/client.py +227 -0
- txtai/database/duckdb.py +150 -0
- txtai/database/embedded.py +76 -0
- txtai/database/encoder/__init__.py +8 -0
- txtai/database/encoder/base.py +37 -0
- txtai/database/encoder/factory.py +56 -0
- txtai/database/encoder/image.py +43 -0
- txtai/database/encoder/serialize.py +28 -0
- txtai/database/factory.py +77 -0
- txtai/database/rdbms.py +569 -0
- txtai/database/schema/__init__.py +6 -0
- txtai/database/schema/orm.py +99 -0
- txtai/database/schema/statement.py +98 -0
- txtai/database/sql/__init__.py +8 -0
- txtai/database/sql/aggregate.py +178 -0
- txtai/database/sql/base.py +189 -0
- txtai/database/sql/expression.py +404 -0
- txtai/database/sql/token.py +342 -0
- txtai/database/sqlite.py +57 -0
- txtai/embeddings/__init__.py +7 -0
- txtai/embeddings/base.py +1107 -0
- txtai/embeddings/index/__init__.py +14 -0
- txtai/embeddings/index/action.py +15 -0
- txtai/embeddings/index/autoid.py +92 -0
- txtai/embeddings/index/configuration.py +71 -0
- txtai/embeddings/index/documents.py +86 -0
- txtai/embeddings/index/functions.py +155 -0
- txtai/embeddings/index/indexes.py +199 -0
- txtai/embeddings/index/indexids.py +60 -0
- txtai/embeddings/index/reducer.py +104 -0
- txtai/embeddings/index/stream.py +67 -0
- txtai/embeddings/index/transform.py +205 -0
- txtai/embeddings/search/__init__.py +11 -0
- txtai/embeddings/search/base.py +344 -0
- txtai/embeddings/search/errors.py +9 -0
- txtai/embeddings/search/explain.py +120 -0
- txtai/embeddings/search/ids.py +61 -0
- txtai/embeddings/search/query.py +69 -0
- txtai/embeddings/search/scan.py +196 -0
- txtai/embeddings/search/terms.py +46 -0
- txtai/graph/__init__.py +10 -0
- txtai/graph/base.py +769 -0
- txtai/graph/factory.py +61 -0
- txtai/graph/networkx.py +275 -0
- txtai/graph/query.py +181 -0
- txtai/graph/rdbms.py +113 -0
- txtai/graph/topics.py +166 -0
- txtai/models/__init__.py +9 -0
- txtai/models/models.py +268 -0
- txtai/models/onnx.py +133 -0
- txtai/models/pooling/__init__.py +9 -0
- txtai/models/pooling/base.py +141 -0
- txtai/models/pooling/cls.py +28 -0
- txtai/models/pooling/factory.py +144 -0
- txtai/models/pooling/late.py +173 -0
- txtai/models/pooling/mean.py +33 -0
- txtai/models/pooling/muvera.py +164 -0
- txtai/models/registry.py +37 -0
- txtai/models/tokendetection.py +122 -0
- txtai/pipeline/__init__.py +17 -0
- txtai/pipeline/audio/__init__.py +11 -0
- txtai/pipeline/audio/audiomixer.py +58 -0
- txtai/pipeline/audio/audiostream.py +94 -0
- txtai/pipeline/audio/microphone.py +244 -0
- txtai/pipeline/audio/signal.py +186 -0
- txtai/pipeline/audio/texttoaudio.py +60 -0
- txtai/pipeline/audio/texttospeech.py +553 -0
- txtai/pipeline/audio/transcription.py +212 -0
- txtai/pipeline/base.py +23 -0
- txtai/pipeline/data/__init__.py +10 -0
- txtai/pipeline/data/filetohtml.py +206 -0
- txtai/pipeline/data/htmltomd.py +414 -0
- txtai/pipeline/data/segmentation.py +178 -0
- txtai/pipeline/data/tabular.py +155 -0
- txtai/pipeline/data/textractor.py +139 -0
- txtai/pipeline/data/tokenizer.py +112 -0
- txtai/pipeline/factory.py +77 -0
- txtai/pipeline/hfmodel.py +111 -0
- txtai/pipeline/hfpipeline.py +96 -0
- txtai/pipeline/image/__init__.py +7 -0
- txtai/pipeline/image/caption.py +55 -0
- txtai/pipeline/image/imagehash.py +90 -0
- txtai/pipeline/image/objects.py +80 -0
- txtai/pipeline/llm/__init__.py +11 -0
- txtai/pipeline/llm/factory.py +86 -0
- txtai/pipeline/llm/generation.py +173 -0
- txtai/pipeline/llm/huggingface.py +218 -0
- txtai/pipeline/llm/litellm.py +90 -0
- txtai/pipeline/llm/llama.py +152 -0
- txtai/pipeline/llm/llm.py +75 -0
- txtai/pipeline/llm/rag.py +477 -0
- txtai/pipeline/nop.py +14 -0
- txtai/pipeline/tensors.py +52 -0
- txtai/pipeline/text/__init__.py +13 -0
- txtai/pipeline/text/crossencoder.py +70 -0
- txtai/pipeline/text/entity.py +140 -0
- txtai/pipeline/text/labels.py +137 -0
- txtai/pipeline/text/lateencoder.py +103 -0
- txtai/pipeline/text/questions.py +48 -0
- txtai/pipeline/text/reranker.py +57 -0
- txtai/pipeline/text/similarity.py +83 -0
- txtai/pipeline/text/summary.py +98 -0
- txtai/pipeline/text/translation.py +298 -0
- txtai/pipeline/train/__init__.py +7 -0
- txtai/pipeline/train/hfonnx.py +196 -0
- txtai/pipeline/train/hftrainer.py +398 -0
- txtai/pipeline/train/mlonnx.py +63 -0
- txtai/scoring/__init__.py +12 -0
- txtai/scoring/base.py +188 -0
- txtai/scoring/bm25.py +29 -0
- txtai/scoring/factory.py +95 -0
- txtai/scoring/pgtext.py +181 -0
- txtai/scoring/sif.py +32 -0
- txtai/scoring/sparse.py +218 -0
- txtai/scoring/terms.py +499 -0
- txtai/scoring/tfidf.py +358 -0
- txtai/serialize/__init__.py +10 -0
- txtai/serialize/base.py +85 -0
- txtai/serialize/errors.py +9 -0
- txtai/serialize/factory.py +29 -0
- txtai/serialize/messagepack.py +42 -0
- txtai/serialize/pickle.py +98 -0
- txtai/serialize/serializer.py +46 -0
- txtai/util/__init__.py +7 -0
- txtai/util/resolver.py +32 -0
- txtai/util/sparsearray.py +62 -0
- txtai/util/template.py +16 -0
- txtai/vectors/__init__.py +8 -0
- txtai/vectors/base.py +476 -0
- txtai/vectors/dense/__init__.py +12 -0
- txtai/vectors/dense/external.py +55 -0
- txtai/vectors/dense/factory.py +121 -0
- txtai/vectors/dense/huggingface.py +44 -0
- txtai/vectors/dense/litellm.py +86 -0
- txtai/vectors/dense/llama.py +84 -0
- txtai/vectors/dense/m2v.py +67 -0
- txtai/vectors/dense/sbert.py +92 -0
- txtai/vectors/dense/words.py +211 -0
- txtai/vectors/recovery.py +57 -0
- txtai/vectors/sparse/__init__.py +7 -0
- txtai/vectors/sparse/base.py +90 -0
- txtai/vectors/sparse/factory.py +55 -0
- txtai/vectors/sparse/sbert.py +34 -0
- txtai/version.py +6 -0
- txtai/workflow/__init__.py +8 -0
- txtai/workflow/base.py +184 -0
- txtai/workflow/execute.py +99 -0
- txtai/workflow/factory.py +42 -0
- txtai/workflow/task/__init__.py +18 -0
- txtai/workflow/task/base.py +490 -0
- txtai/workflow/task/console.py +24 -0
- txtai/workflow/task/export.py +64 -0
- txtai/workflow/task/factory.py +89 -0
- txtai/workflow/task/file.py +28 -0
- txtai/workflow/task/image.py +36 -0
- txtai/workflow/task/retrieve.py +61 -0
- txtai/workflow/task/service.py +102 -0
- txtai/workflow/task/storage.py +110 -0
- txtai/workflow/task/stream.py +33 -0
- txtai/workflow/task/template.py +116 -0
- txtai/workflow/task/url.py +20 -0
- txtai/workflow/task/workflow.py +14 -0
txtai/workflow/task/base.py
@@ -0,0 +1,490 @@
+"""
+Task module
+"""
+
+import logging
+import re
+import types
+
+import numpy as np
+import torch
+
+# Logging configuration
+logger = logging.getLogger(__name__)
+
+
+class Task:
+    """
+    Base class for all workflow tasks.
+    """
+
+    def __init__(
+        self,
+        action=None,
+        select=None,
+        unpack=True,
+        column=None,
+        merge="hstack",
+        initialize=None,
+        finalize=None,
+        concurrency=None,
+        onetomany=True,
+        **kwargs,
+    ):
+        """
+        Creates a new task. A task defines two methods, type of data it accepts and the action to execute
+        for each data element. Action is a callable function or list of callable functions.
+
+        Args:
+            action: action(s) to execute on each data element
+            select: filter(s) used to select data to process
+            unpack: if data elements should be unpacked or unwrapped from (id, data, tag) tuples
+            column: column index to select if element is a tuple, defaults to all
+            merge: merge mode for joining multi-action outputs, defaults to hstack
+            initialize: action to execute before processing
+            finalize: action to execute after processing
+            concurrency: sets concurrency method when execute instance available
+                valid values: "thread" for thread-based concurrency, "process" for process-based concurrency
+            onetomany: if one-to-many data transformations should be enabled, defaults to True
+            kwargs: additional keyword arguments
+        """
+
+        # Standardize into list of actions
+        if not action:
+            action = []
+        elif not isinstance(action, list):
+            action = [action]
+
+        self.action = action
+        self.select = select
+        self.unpack = unpack
+        self.column = column
+        self.merge = merge
+        self.initialize = initialize
+        self.finalize = finalize
+        self.concurrency = concurrency
+        self.onetomany = onetomany
+
+        # Check for custom registration. Adds additional instance members and validates required dependencies available.
+        if hasattr(self, "register"):
+            self.register(**kwargs)
+        elif kwargs:
+            # Raise error if additional keyword arguments passed in without register method
+            kwargs = ", ".join(f"'{kw}'" for kw in kwargs)
+            raise TypeError(f"__init__() got unexpected keyword arguments: {kwargs}")
+
+    def __call__(self, elements, executor=None):
+        """
+        Executes action for a list of data elements.
+
+        Args:
+            elements: iterable data elements
+            executor: execute instance, enables concurrent task actions
+
+        Returns:
+            transformed data elements
+        """
+
+        if isinstance(elements, list):
+            return self.filteredrun(elements, executor)
+
+        return self.run(elements, executor)
+
+    def filteredrun(self, elements, executor):
+        """
+        Executes a filtered run, which will tag all inputs with a process id, filter elements down to elements the
+        task can handle and execute on that subset. Items not selected for processing will be returned unmodified.
+
+        Args:
+            elements: iterable data elements
+            executor: execute instance, enables concurrent task actions
+
+        Returns:
+            transformed data elements
+        """
+
+        # Build list of elements with unique process ids
+        indexed = list(enumerate(elements))
+
+        # Filter data down to data this task handles
+        data = [(x, self.upack(element)) for x, element in indexed if self.accept(self.upack(element, True))]
+
+        # Get list of filtered process ids
+        ids = [x for x, _ in data]
+
+        # Prepare elements and execute task action(s)
+        results = self.execute([self.prepare(element) for _, element in data], executor)
+
+        # Pack results back into elements
+        if self.merge:
+            elements = self.filteredpack(results, indexed, ids)
+        else:
+            elements = [self.filteredpack(r, indexed, ids) for r in results]
+
+        return elements
+
+    def filteredpack(self, results, indexed, ids):
+        """
+        Processes and packs results back into original input elements.
+
+        Args:
+            results: task results
+            indexed: original elements indexed by process id
+            ids: process ids accepted by this task
+
+        Returns:
+            packed elements
+        """
+
+        # Update with transformed elements. Handle one to many transformations.
+        elements = []
+        for x, element in indexed:
+            if x in ids:
+                # Get result for process id
+                result = results[ids.index(x)]
+
+                if isinstance(result, OneToMany):
+                    # One to many transformations
+                    elements.extend([self.pack(element, r) for r in result])
+                else:
+                    # One to one transformations
+                    elements.append(self.pack(element, result))
+            else:
+                # Pass unprocessed elements through
+                elements.append(element)
+
+        return elements
+
+    def run(self, elements, executor):
+        """
+        Executes a task run for elements. A standard run processes all elements.
+
+        Args:
+            elements: iterable data elements
+            executor: execute instance, enables concurrent task actions
+
+        Returns:
+            transformed data elements
+        """
+
+        # Execute task actions
+        results = self.execute(elements, executor)
+
+        # Handle one to many transformations
+        if isinstance(results, list):
+            elements = []
+            for result in results:
+                if isinstance(result, OneToMany):
+                    # One to many transformations
+                    elements.extend(result)
+                else:
+                    # One to one transformations
+                    elements.append(result)
+
+            return elements
+
+        return results
+
+    def accept(self, element):
+        """
+        Determines if this task can handle the input data format.
+
+        Args:
+            element: input data element
+
+        Returns:
+            True if this task can process this data element, False otherwise
+        """
+
+        return (isinstance(element, str) and re.search(self.select, element.lower())) if element is not None and self.select else True
+
+    def upack(self, element, force=False):
+        """
+        Unpacks data for processing.
+
+        Args:
+            element: input data element
+            force: if True, data is unpacked even if task has unpack set to False
+
+        Returns:
+            data
+        """
+
+        # Extract data from (id, data, tag) formatted elements
+        if (self.unpack or force) and isinstance(element, tuple) and len(element) > 1:
+            return element[1]
+
+        return element
+
+    def pack(self, element, data):
+        """
+        Packs data after processing.
+
+        Args:
+            element: transformed data element
+            data: item to pack element into
+
+        Returns:
+            packed data
+        """
+
+        # Pack data into (id, data, tag) formatted elements
+        if self.unpack and isinstance(element, tuple) and len(element) > 1:
+            # If new data is a (id, data, tag) tuple use that except for multi-action "hstack" merges which produce tuples
+            if isinstance(data, tuple) and (len(self.action) <= 1 or self.merge != "hstack"):
+                return data
+
+            # Create a copy of tuple, update data element and return
+            element = list(element)
+            element[1] = data
+            return tuple(element)
+
+        return data
+
+    def prepare(self, element):
+        """
+        Method that allows downstream tasks to prepare data element for processing.
+
+        Args:
+            element: input data element
+
+        Returns:
+            data element ready for processing
+        """
+
+        return element
+
+    def execute(self, elements, executor):
+        """
+        Executes action(s) on elements.
+
+        Args:
+            elements: list of data elements
+            executor: execute instance, enables concurrent task actions
+
+        Returns:
+            transformed data elements
+        """
+
+        if self.action:
+            # Run actions
+            outputs = []
+            for x, action in enumerate(self.action):
+                # Filter elements by column index if necessary - supports a single int or an action index to column index mapping
+                index = self.column[x] if isinstance(self.column, dict) else self.column
+                inputs = [self.extract(e, index) for e in elements] if index is not None else elements
+
+                # Queue arguments for executor, process immediately if no executor available
+                outputs.append((action, inputs) if executor else self.process(action, inputs))
+
+            # Run with executor if available
+            if executor:
+                outputs = executor.run(self.concurrency, self.process, outputs)
+
+            # Run post process operations
+            return self.postprocess(outputs)
+
+        return elements
+
+    def extract(self, element, index):
+        """
+        Extracts a column from element by index if the element is a tuple.
+
+        Args:
+            element: input element
+            index: column index
+
+        Returns:
+            extracted column
+        """
+
+        if isinstance(element, tuple):
+            if not self.unpack and len(element) == 3 and isinstance(element[1], tuple):
+                return (element[0], element[1][index], element[2])
+
+            return element[index]
+
+        return element
+
+    def process(self, action, inputs):
+        """
+        Executes action using inputs as arguments.
+
+        Args:
+            action: callable object
+            inputs: action inputs
+
+        Returns:
+            action outputs
+        """
+
+        # Log inputs
+        logger.debug("Inputs: %s", inputs)
+
+        # Execute action and get outputs
+        outputs = action(inputs)
+
+        # Consume generator output, if necessary
+        if isinstance(outputs, types.GeneratorType):
+            outputs = list(outputs)
+
+        # Log outputs
+        logger.debug("Outputs: %s", outputs)
+
+        return outputs
+
+    def postprocess(self, outputs):
+        """
+        Runs post process routines after a task action.
+
+        Args:
+            outputs: task outputs
+
+        Returns:
+            postprocessed outputs
+        """
+
+        # Unpack single action tasks
+        if len(self.action) == 1:
+            return self.single(outputs[0])
+
+        # Return unmodified outputs when merge set to None
+        if not self.merge:
+            return outputs
+
+        if self.merge == "vstack":
+            return self.vstack(outputs)
+        if self.merge == "concat":
+            return self.concat(outputs)
+
+        # Default mode is hstack
+        return self.hstack(outputs)
+
+    def single(self, outputs):
+        """
+        Post processes and returns single action outputs.
+
+        Args:
+            outputs: outputs from a single task
+
+        Returns:
+            post processed outputs
+        """
+
+        if self.onetomany and isinstance(outputs, list):
+            # Wrap one to many transformations
+            outputs = [OneToMany(output) if isinstance(output, list) else output for output in outputs]
+
+        return outputs
+
+    def vstack(self, outputs):
+        """
+        Merges outputs row-wise. Returns a list of lists which will be interpreted as a one to many transformation.
+
+        Row-wise merge example (2 actions)
+
+          Inputs: [a, b, c]
+
+          Outputs => [[a1, b1, c1], [a2, b2, c2]]
+
+          Row Merge => [[a1, a2], [b1, b2], [c1, c2]] = [a1, a2, b1, b2, c1, c2]
+
+        Args:
+            outputs: task outputs
+
+        Returns:
+            list of aggregated/zipped outputs as one to many transforms (row-wise)
+        """
+
+        # If all outputs are numpy arrays, use native method
+        if all(isinstance(output, np.ndarray) for output in outputs):
+            return np.concatenate(np.stack(outputs, axis=1))
+
+        # If all outputs are torch tensors, use native method
+        # pylint: disable=E1101
+        if all(torch.is_tensor(output) for output in outputs):
+            return torch.cat(tuple(torch.stack(outputs, axis=1)))
+
+        # Flatten into lists of outputs per input row. Wrap as one to many transformation.
+        merge = []
+        for x in zip(*outputs):
+            combine = []
+            for y in x:
+                if isinstance(y, list):
+                    combine.extend(y)
+                else:
+                    combine.append(y)
+
+            merge.append(OneToMany(combine))
+
+        return merge
+
+    def hstack(self, outputs):
+        """
+        Merges outputs column-wise. Returns a list of tuples which will be interpreted as a one to one transformation.
+
+        Column-wise merge example (2 actions)
+
+          Inputs: [a, b, c]
+
+          Outputs => [[a1, b1, c1], [a2, b2, c2]]
+
+          Column Merge => [(a1, a2), (b1, b2), (c1, c2)]
+
+        Args:
+            outputs: task outputs
+
+        Returns:
+            list of aggregated/zipped outputs as tuples (column-wise)
+        """
+
+        # If all outputs are numpy arrays, use native method
+        if all(isinstance(output, np.ndarray) for output in outputs):
+            return np.stack(outputs, axis=1)
+
+        # If all outputs are torch tensors, use native method
+        # pylint: disable=E1101
+        if all(torch.is_tensor(output) for output in outputs):
+            return torch.stack(outputs, axis=1)
+
+        return list(zip(*outputs))
+
+    def concat(self, outputs):
+        """
+        Merges outputs column-wise and concats values together into a string. Returns a list of strings.
+
+        Concat merge example (2 actions)
+
+          Inputs: [a, b, c]
+
+          Outputs => [[a1, b1, c1], [a2, b2, c2]]
+
+          Concat Merge => [(a1, a2), (b1, b2), (c1, c2)] => ["a1. a2", "b1. b2", "c1. c2"]
+
+        Args:
+            outputs: task outputs
+
+        Returns:
+            list of concat outputs
+        """
+
+        return [". ".join([str(y) for y in x if y]) for x in self.hstack(outputs)]
+
+
+class OneToMany:
+    """
+    Encapsulates list output for a one to many transformation.
+    """
+
+    def __init__(self, values):
+        """
+        Creates a new OneToMany transformation.
+
+        Args:
+            values: list of outputs
+        """
+
+        self.values = values
+
+    def __iter__(self):
+        return self.values.__iter__()
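For context (not part of the package diff): a minimal usage sketch of the Task class above, assuming Task is re-exported from txtai.workflow as in upstream txtai. It shows the one-to-one path and how the default "hstack" and the "concat" merge modes combine multi-action outputs; the lambda actions and values are illustrative only.

    from txtai.workflow import Task

    # Single action: elements are transformed one-to-one
    task = Task(action=lambda elements: [x * 2 for x in elements])
    print(task([1, 2, 3]))  # [2, 4, 6]

    # Two actions with the default "hstack" merge: outputs are zipped column-wise into tuples
    task = Task(action=[lambda e: [x * 2 for x in e], lambda e: [x * 3 for x in e]])
    print(task([1, 2, 3]))  # [(2, 3), (4, 6), (6, 9)]

    # "concat" merge joins the column-wise values into ". "-delimited strings
    task = Task(action=[lambda e: [x * 2 for x in e], lambda e: [x * 3 for x in e]], merge="concat")
    print(task([1, 2, 3]))  # ['2. 3', '4. 6', '6. 9']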
txtai/workflow/task/console.py
@@ -0,0 +1,24 @@
+"""
+ConsoleTask module
+"""
+
+import json
+
+from .base import Task
+
+
+class ConsoleTask(Task):
+    """
+    Task that prints task elements to the console.
+    """
+
+    def __call__(self, elements, executor=None):
+        # Run task
+        outputs = super().__call__(elements, executor)
+
+        # Print inputs and outputs to console
+        print("Inputs:", json.dumps(elements, indent=2))
+        print("Outputs:", json.dumps(outputs, indent=2))
+
+        # Return results
+        return outputs
txtai/workflow/task/export.py
@@ -0,0 +1,64 @@
+"""
+ExportTask module
+"""
+
+import datetime
+import os
+
+# Conditional import
+try:
+    import pandas as pd
+
+    PANDAS = True
+except ImportError:
+    PANDAS = False
+
+from .base import Task
+
+
+class ExportTask(Task):
+    """
+    Task that exports task elements using Pandas.
+    """
+
+    def register(self, output=None, timestamp=None):
+        """
+        Add export parameters to task. Checks if required dependencies are installed.
+
+        Args:
+            output: output file path
+            timestamp: true if output file should be timestamped
+        """
+
+        if not PANDAS:
+            raise ImportError('ExportTask is not available - install "workflow" extra to enable')
+
+        # pylint: disable=W0201
+        self.output = output
+        self.timestamp = timestamp
+
+    def __call__(self, elements, executor=None):
+        # Run task
+        outputs = super().__call__(elements, executor)
+
+        # Get output file extension
+        output = self.output
+        parts = list(os.path.splitext(output))
+        extension = parts[-1].lower()
+
+        # Add timestamp to filename
+        if self.timestamp:
+            timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
+            parts[-1] = timestamp + parts[-1]
+
+            # Create full path to output file
+            output = ".".join(parts)
+
+        # Write output
+        if extension == ".xlsx":
+            pd.DataFrame(outputs).to_excel(output, index=False)
+        else:
+            pd.DataFrame(outputs).to_csv(output, index=False)
+
+        # Return results
+        return outputs
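For context (not part of the package diff): an illustrative sketch of ExportTask, assuming it is importable from txtai.workflow.task.export and that pandas (the "workflow" extra) is installed. The action and output path are placeholders.

    from txtai.workflow.task.export import ExportTask

    # Extra keyword arguments are routed to register(); output sets the export file path
    task = ExportTask(action=lambda elements: [{"id": x, "value": x * 2} for x in elements], output="results.csv")

    # Runs the action, writes the rows to results.csv via pandas, then returns them
    rows = task([1, 2, 3])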
txtai/workflow/task/factory.py
@@ -0,0 +1,89 @@
+"""
+Task factory module
+"""
+
+import functools
+
+from ...util import Resolver
+
+
+class TaskFactory:
+    """
+    Task factory. Creates new Task instances.
+    """
+
+    @staticmethod
+    def get(task):
+        """
+        Gets a new instance of task class.
+
+        Args:
+            task: Task instance class
+
+        Returns:
+            Task class
+        """
+
+        # Local task if no package
+        if "." not in task:
+            # Get parent package
+            task = ".".join(__name__.split(".")[:-1]) + "." + task.capitalize() + "Task"
+
+        # Attempt to load custom task
+        return Resolver()(task)
+
+    @staticmethod
+    def create(config, task):
+        """
+        Creates a new Task instance.
+
+        Args:
+            config: Task configuration
+            task: Task instance class
+
+        Returns:
+            Task
+        """
+
+        # Create lambda function if additional arguments present
+        if "args" in config:
+            args = config.pop("args")
+            action = config["action"]
+            if action:
+                if isinstance(action, list):
+                    config["action"] = [Partial.create(a, args[i]) for i, a in enumerate(action)]
+                else:
+                    # Accept keyword or positional arguments
+                    config["action"] = lambda x: action(x, **args) if isinstance(args, dict) else action(x, *args)
+
+        # Get Task instance
+        return TaskFactory.get(task)(**config)
+
+
+class Partial(functools.partial):
+    """
+    Modifies functools.partial to prepend arguments vs append.
+    """
+
+    @staticmethod
+    def create(action, args):
+        """
+        Creates a new Partial function.
+
+        Args:
+            action: action to execute
+            args: arguments
+
+        Returns:
+            Partial
+        """
+
+        return Partial(action, **args) if isinstance(args, dict) else Partial(action, *args) if args else Partial(action)
+
+    def __call__(self, *args, **kwargs):
+        # Update keyword arguments
+        kw = self.keywords.copy()
+        kw.update(kwargs)
+
+        # Execute function with new arguments prepended to default arguments
+        return self.func(*(args + self.args), **kw)
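For context (not part of the package diff): a sketch of how TaskFactory.create binds the optional "args" entry to an action before constructing a task, assuming the module layout above and that the short name "console" resolves to ConsoleTask via TaskFactory.get. The scale function and config values are hypothetical.

    from txtai.workflow.task.factory import TaskFactory

    # Hypothetical action that takes a parameter beyond the task elements
    def scale(elements, factor):
        return [x * factor for x in elements]

    # "args" is popped from config and bound to the action; remaining keys go to the task constructor
    config = {"action": scale, "args": [10]}

    # "console" resolves to ConsoleTask, which also prints inputs/outputs when run
    task = TaskFactory.create(config, "console")
    print(task([1, 2]))  # [10, 20]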
txtai/workflow/task/file.py
@@ -0,0 +1,28 @@
+"""
+FileTask module
+"""
+
+import os
+import re
+
+from .base import Task
+
+
+class FileTask(Task):
+    """
+    Task that processes file paths
+    """
+
+    # File prefix
+    FILE = r"file:\/\/"
+
+    def accept(self, element):
+        # Replace file prefixes
+        element = re.sub(FileTask.FILE, "", element)
+
+        # Only accept file paths that exist
+        return super().accept(element) and isinstance(element, str) and os.path.exists(element)
+
+    def prepare(self, element):
+        # Replace file prefixes
+        return re.sub(FileTask.FILE, "", element)
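For context (not part of the package diff): an illustrative check of FileTask behavior, assuming it is importable from txtai.workflow.task.file. accept() strips the file:// prefix and only accepts paths that exist; prepare() returns the plain path. The paths below are placeholders.

    import tempfile

    from txtai.workflow.task.file import FileTask

    task = FileTask()

    # accept() strips the file:// prefix and only accepts paths that exist on disk
    with tempfile.NamedTemporaryFile() as tmp:
        print(task.accept(f"file://{tmp.name}"))  # True
    print(task.accept("file:///does/not/exist"))  # False

    # prepare() returns the plain filesystem path
    print(task.prepare("file:///tmp/example.txt"))  # /tmp/example.txt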