mseep-txtai 9.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251) hide show
  1. mseep_txtai-9.1.1.dist-info/METADATA +262 -0
  2. mseep_txtai-9.1.1.dist-info/RECORD +251 -0
  3. mseep_txtai-9.1.1.dist-info/WHEEL +5 -0
  4. mseep_txtai-9.1.1.dist-info/licenses/LICENSE +190 -0
  5. mseep_txtai-9.1.1.dist-info/top_level.txt +1 -0
  6. txtai/__init__.py +16 -0
  7. txtai/agent/__init__.py +12 -0
  8. txtai/agent/base.py +54 -0
  9. txtai/agent/factory.py +39 -0
  10. txtai/agent/model.py +107 -0
  11. txtai/agent/placeholder.py +16 -0
  12. txtai/agent/tool/__init__.py +7 -0
  13. txtai/agent/tool/embeddings.py +69 -0
  14. txtai/agent/tool/factory.py +130 -0
  15. txtai/agent/tool/function.py +49 -0
  16. txtai/ann/__init__.py +7 -0
  17. txtai/ann/base.py +153 -0
  18. txtai/ann/dense/__init__.py +11 -0
  19. txtai/ann/dense/annoy.py +72 -0
  20. txtai/ann/dense/factory.py +76 -0
  21. txtai/ann/dense/faiss.py +233 -0
  22. txtai/ann/dense/hnsw.py +104 -0
  23. txtai/ann/dense/numpy.py +164 -0
  24. txtai/ann/dense/pgvector.py +323 -0
  25. txtai/ann/dense/sqlite.py +303 -0
  26. txtai/ann/dense/torch.py +38 -0
  27. txtai/ann/sparse/__init__.py +7 -0
  28. txtai/ann/sparse/factory.py +61 -0
  29. txtai/ann/sparse/ivfsparse.py +377 -0
  30. txtai/ann/sparse/pgsparse.py +56 -0
  31. txtai/api/__init__.py +18 -0
  32. txtai/api/application.py +134 -0
  33. txtai/api/authorization.py +53 -0
  34. txtai/api/base.py +159 -0
  35. txtai/api/cluster.py +295 -0
  36. txtai/api/extension.py +19 -0
  37. txtai/api/factory.py +40 -0
  38. txtai/api/responses/__init__.py +7 -0
  39. txtai/api/responses/factory.py +30 -0
  40. txtai/api/responses/json.py +56 -0
  41. txtai/api/responses/messagepack.py +51 -0
  42. txtai/api/route.py +41 -0
  43. txtai/api/routers/__init__.py +25 -0
  44. txtai/api/routers/agent.py +38 -0
  45. txtai/api/routers/caption.py +42 -0
  46. txtai/api/routers/embeddings.py +280 -0
  47. txtai/api/routers/entity.py +42 -0
  48. txtai/api/routers/extractor.py +28 -0
  49. txtai/api/routers/labels.py +47 -0
  50. txtai/api/routers/llm.py +61 -0
  51. txtai/api/routers/objects.py +42 -0
  52. txtai/api/routers/openai.py +191 -0
  53. txtai/api/routers/rag.py +61 -0
  54. txtai/api/routers/reranker.py +46 -0
  55. txtai/api/routers/segmentation.py +42 -0
  56. txtai/api/routers/similarity.py +48 -0
  57. txtai/api/routers/summary.py +46 -0
  58. txtai/api/routers/tabular.py +42 -0
  59. txtai/api/routers/textractor.py +42 -0
  60. txtai/api/routers/texttospeech.py +33 -0
  61. txtai/api/routers/transcription.py +42 -0
  62. txtai/api/routers/translation.py +46 -0
  63. txtai/api/routers/upload.py +36 -0
  64. txtai/api/routers/workflow.py +28 -0
  65. txtai/app/__init__.py +5 -0
  66. txtai/app/base.py +821 -0
  67. txtai/archive/__init__.py +9 -0
  68. txtai/archive/base.py +104 -0
  69. txtai/archive/compress.py +51 -0
  70. txtai/archive/factory.py +25 -0
  71. txtai/archive/tar.py +49 -0
  72. txtai/archive/zip.py +35 -0
  73. txtai/cloud/__init__.py +8 -0
  74. txtai/cloud/base.py +106 -0
  75. txtai/cloud/factory.py +70 -0
  76. txtai/cloud/hub.py +101 -0
  77. txtai/cloud/storage.py +125 -0
  78. txtai/console/__init__.py +5 -0
  79. txtai/console/__main__.py +22 -0
  80. txtai/console/base.py +264 -0
  81. txtai/data/__init__.py +10 -0
  82. txtai/data/base.py +138 -0
  83. txtai/data/labels.py +42 -0
  84. txtai/data/questions.py +135 -0
  85. txtai/data/sequences.py +48 -0
  86. txtai/data/texts.py +68 -0
  87. txtai/data/tokens.py +28 -0
  88. txtai/database/__init__.py +14 -0
  89. txtai/database/base.py +342 -0
  90. txtai/database/client.py +227 -0
  91. txtai/database/duckdb.py +150 -0
  92. txtai/database/embedded.py +76 -0
  93. txtai/database/encoder/__init__.py +8 -0
  94. txtai/database/encoder/base.py +37 -0
  95. txtai/database/encoder/factory.py +56 -0
  96. txtai/database/encoder/image.py +43 -0
  97. txtai/database/encoder/serialize.py +28 -0
  98. txtai/database/factory.py +77 -0
  99. txtai/database/rdbms.py +569 -0
  100. txtai/database/schema/__init__.py +6 -0
  101. txtai/database/schema/orm.py +99 -0
  102. txtai/database/schema/statement.py +98 -0
  103. txtai/database/sql/__init__.py +8 -0
  104. txtai/database/sql/aggregate.py +178 -0
  105. txtai/database/sql/base.py +189 -0
  106. txtai/database/sql/expression.py +404 -0
  107. txtai/database/sql/token.py +342 -0
  108. txtai/database/sqlite.py +57 -0
  109. txtai/embeddings/__init__.py +7 -0
  110. txtai/embeddings/base.py +1107 -0
  111. txtai/embeddings/index/__init__.py +14 -0
  112. txtai/embeddings/index/action.py +15 -0
  113. txtai/embeddings/index/autoid.py +92 -0
  114. txtai/embeddings/index/configuration.py +71 -0
  115. txtai/embeddings/index/documents.py +86 -0
  116. txtai/embeddings/index/functions.py +155 -0
  117. txtai/embeddings/index/indexes.py +199 -0
  118. txtai/embeddings/index/indexids.py +60 -0
  119. txtai/embeddings/index/reducer.py +104 -0
  120. txtai/embeddings/index/stream.py +67 -0
  121. txtai/embeddings/index/transform.py +205 -0
  122. txtai/embeddings/search/__init__.py +11 -0
  123. txtai/embeddings/search/base.py +344 -0
  124. txtai/embeddings/search/errors.py +9 -0
  125. txtai/embeddings/search/explain.py +120 -0
  126. txtai/embeddings/search/ids.py +61 -0
  127. txtai/embeddings/search/query.py +69 -0
  128. txtai/embeddings/search/scan.py +196 -0
  129. txtai/embeddings/search/terms.py +46 -0
  130. txtai/graph/__init__.py +10 -0
  131. txtai/graph/base.py +769 -0
  132. txtai/graph/factory.py +61 -0
  133. txtai/graph/networkx.py +275 -0
  134. txtai/graph/query.py +181 -0
  135. txtai/graph/rdbms.py +113 -0
  136. txtai/graph/topics.py +166 -0
  137. txtai/models/__init__.py +9 -0
  138. txtai/models/models.py +268 -0
  139. txtai/models/onnx.py +133 -0
  140. txtai/models/pooling/__init__.py +9 -0
  141. txtai/models/pooling/base.py +141 -0
  142. txtai/models/pooling/cls.py +28 -0
  143. txtai/models/pooling/factory.py +144 -0
  144. txtai/models/pooling/late.py +173 -0
  145. txtai/models/pooling/mean.py +33 -0
  146. txtai/models/pooling/muvera.py +164 -0
  147. txtai/models/registry.py +37 -0
  148. txtai/models/tokendetection.py +122 -0
  149. txtai/pipeline/__init__.py +17 -0
  150. txtai/pipeline/audio/__init__.py +11 -0
  151. txtai/pipeline/audio/audiomixer.py +58 -0
  152. txtai/pipeline/audio/audiostream.py +94 -0
  153. txtai/pipeline/audio/microphone.py +244 -0
  154. txtai/pipeline/audio/signal.py +186 -0
  155. txtai/pipeline/audio/texttoaudio.py +60 -0
  156. txtai/pipeline/audio/texttospeech.py +553 -0
  157. txtai/pipeline/audio/transcription.py +212 -0
  158. txtai/pipeline/base.py +23 -0
  159. txtai/pipeline/data/__init__.py +10 -0
  160. txtai/pipeline/data/filetohtml.py +206 -0
  161. txtai/pipeline/data/htmltomd.py +414 -0
  162. txtai/pipeline/data/segmentation.py +178 -0
  163. txtai/pipeline/data/tabular.py +155 -0
  164. txtai/pipeline/data/textractor.py +139 -0
  165. txtai/pipeline/data/tokenizer.py +112 -0
  166. txtai/pipeline/factory.py +77 -0
  167. txtai/pipeline/hfmodel.py +111 -0
  168. txtai/pipeline/hfpipeline.py +96 -0
  169. txtai/pipeline/image/__init__.py +7 -0
  170. txtai/pipeline/image/caption.py +55 -0
  171. txtai/pipeline/image/imagehash.py +90 -0
  172. txtai/pipeline/image/objects.py +80 -0
  173. txtai/pipeline/llm/__init__.py +11 -0
  174. txtai/pipeline/llm/factory.py +86 -0
  175. txtai/pipeline/llm/generation.py +173 -0
  176. txtai/pipeline/llm/huggingface.py +218 -0
  177. txtai/pipeline/llm/litellm.py +90 -0
  178. txtai/pipeline/llm/llama.py +152 -0
  179. txtai/pipeline/llm/llm.py +75 -0
  180. txtai/pipeline/llm/rag.py +477 -0
  181. txtai/pipeline/nop.py +14 -0
  182. txtai/pipeline/tensors.py +52 -0
  183. txtai/pipeline/text/__init__.py +13 -0
  184. txtai/pipeline/text/crossencoder.py +70 -0
  185. txtai/pipeline/text/entity.py +140 -0
  186. txtai/pipeline/text/labels.py +137 -0
  187. txtai/pipeline/text/lateencoder.py +103 -0
  188. txtai/pipeline/text/questions.py +48 -0
  189. txtai/pipeline/text/reranker.py +57 -0
  190. txtai/pipeline/text/similarity.py +83 -0
  191. txtai/pipeline/text/summary.py +98 -0
  192. txtai/pipeline/text/translation.py +298 -0
  193. txtai/pipeline/train/__init__.py +7 -0
  194. txtai/pipeline/train/hfonnx.py +196 -0
  195. txtai/pipeline/train/hftrainer.py +398 -0
  196. txtai/pipeline/train/mlonnx.py +63 -0
  197. txtai/scoring/__init__.py +12 -0
  198. txtai/scoring/base.py +188 -0
  199. txtai/scoring/bm25.py +29 -0
  200. txtai/scoring/factory.py +95 -0
  201. txtai/scoring/pgtext.py +181 -0
  202. txtai/scoring/sif.py +32 -0
  203. txtai/scoring/sparse.py +218 -0
  204. txtai/scoring/terms.py +499 -0
  205. txtai/scoring/tfidf.py +358 -0
  206. txtai/serialize/__init__.py +10 -0
  207. txtai/serialize/base.py +85 -0
  208. txtai/serialize/errors.py +9 -0
  209. txtai/serialize/factory.py +29 -0
  210. txtai/serialize/messagepack.py +42 -0
  211. txtai/serialize/pickle.py +98 -0
  212. txtai/serialize/serializer.py +46 -0
  213. txtai/util/__init__.py +7 -0
  214. txtai/util/resolver.py +32 -0
  215. txtai/util/sparsearray.py +62 -0
  216. txtai/util/template.py +16 -0
  217. txtai/vectors/__init__.py +8 -0
  218. txtai/vectors/base.py +476 -0
  219. txtai/vectors/dense/__init__.py +12 -0
  220. txtai/vectors/dense/external.py +55 -0
  221. txtai/vectors/dense/factory.py +121 -0
  222. txtai/vectors/dense/huggingface.py +44 -0
  223. txtai/vectors/dense/litellm.py +86 -0
  224. txtai/vectors/dense/llama.py +84 -0
  225. txtai/vectors/dense/m2v.py +67 -0
  226. txtai/vectors/dense/sbert.py +92 -0
  227. txtai/vectors/dense/words.py +211 -0
  228. txtai/vectors/recovery.py +57 -0
  229. txtai/vectors/sparse/__init__.py +7 -0
  230. txtai/vectors/sparse/base.py +90 -0
  231. txtai/vectors/sparse/factory.py +55 -0
  232. txtai/vectors/sparse/sbert.py +34 -0
  233. txtai/version.py +6 -0
  234. txtai/workflow/__init__.py +8 -0
  235. txtai/workflow/base.py +184 -0
  236. txtai/workflow/execute.py +99 -0
  237. txtai/workflow/factory.py +42 -0
  238. txtai/workflow/task/__init__.py +18 -0
  239. txtai/workflow/task/base.py +490 -0
  240. txtai/workflow/task/console.py +24 -0
  241. txtai/workflow/task/export.py +64 -0
  242. txtai/workflow/task/factory.py +89 -0
  243. txtai/workflow/task/file.py +28 -0
  244. txtai/workflow/task/image.py +36 -0
  245. txtai/workflow/task/retrieve.py +61 -0
  246. txtai/workflow/task/service.py +102 -0
  247. txtai/workflow/task/storage.py +110 -0
  248. txtai/workflow/task/stream.py +33 -0
  249. txtai/workflow/task/template.py +116 -0
  250. txtai/workflow/task/url.py +20 -0
  251. txtai/workflow/task/workflow.py +14 -0
txtai/api/base.py ADDED
@@ -0,0 +1,159 @@
1
+ """
2
+ API module
3
+ """
4
+
5
+ import json
6
+
7
+ from .cluster import Cluster
8
+
9
+ from ..app import Application
10
+
11
+
12
class API(Application):
    """
    Base API template. The API is an extended txtai application, adding the ability to cluster API instances together.

    Downstream applications can extend this base template to add/modify functionality.
    """

    def __init__(self, config, loaddata=True):
        """
        Creates a new API instance.

        Args:
            config: application configuration, a truthy "cluster" entry enables an embeddings cluster
            loaddata: passed through to Application
        """

        super().__init__(config, loaddata)

        # Embeddings cluster
        self.cluster = None
        if self.config.get("cluster"):
            self.cluster = Cluster(self.config["cluster"])

    # pylint: disable=W0221
    def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False, request=None):
        """
        Finds documents most similar to the input query. Runs against the cluster when one is configured,
        otherwise against this application.

        Args:
            query: input query
            limit: maximum results, bounded using the limit method
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: dict (or JSON string) of named parameters to bind to placeholders
            graph: return graph results if True, string values such as "true"/"false" are parsed into booleans
            request: optional HTTP request, when it has query_params all search options are read from it

        Returns:
            search results
        """

        # When search is invoked via the API, options are set from the request
        # When search is invoked directly, options are set using the method parameters
        if request and hasattr(request, "query_params"):
            params = request.query_params
            limit = params.get("limit")
            weights = params.get("weights")
            index = params.get("index")
            parameters = params.get("parameters")
            graph = params.get("graph")

        limit = self.limit(limit)
        weights = self.weights(weights)

        # Decode parameters
        parameters = json.loads(parameters) if parameters and isinstance(parameters, str) else parameters

        # Query string values are always strings. Parse the graph flag, otherwise "False" would be truthy.
        if isinstance(graph, str):
            graph = graph.strip().lower() in ("true", "1", "yes", "on")

        if self.cluster:
            return self.cluster.search(query, limit, weights, index, parameters, graph)

        return super().search(query, limit, weights, index, parameters, graph)

    def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to each of the input queries. Runs against the cluster when one is configured,
        otherwise against this application.

        Args:
            queries: input queries
            limit: maximum results per query (bounded using the limit method when running against a cluster)
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: list of dicts of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            search results per query
        """

        if self.cluster:
            return self.cluster.batchsearch(queries, self.limit(limit), weights, index, parameters, graph)

        return super().batchsearch(queries, limit, weights, index, parameters, graph)

    def add(self, documents):
        """
        Adds a batch of documents for indexing.

        Downstream applications can override this method to also store full documents in an external system.

        Args:
            documents: list of {id: value, text: value}

        Returns:
            unmodified input documents
        """

        if self.cluster:
            self.cluster.add(documents)
        else:
            super().add(documents)

        return documents

    def index(self):
        """
        Builds an embeddings index for previously batched documents.
        """

        if self.cluster:
            self.cluster.index()
        else:
            super().index()

    def upsert(self):
        """
        Runs an embeddings upsert operation for previously batched documents.
        """

        if self.cluster:
            self.cluster.upsert()
        else:
            super().upsert()

    def delete(self, ids):
        """
        Deletes from an embeddings index. Returns list of ids deleted.

        Args:
            ids: list of ids to delete

        Returns:
            ids deleted
        """

        if self.cluster:
            return self.cluster.delete(ids)

        return super().delete(ids)

    def reindex(self, config, function=None):
        """
        Recreates this embeddings index using config. This method only works if document content storage is enabled.

        Args:
            config: new config
            function: optional function to prepare content for indexing
        """

        if self.cluster:
            self.cluster.reindex(config, function)
        else:
            super().reindex(config, function)

    def count(self):
        """
        Total number of elements in this embeddings index.

        Returns:
            number of elements in embeddings index
        """

        if self.cluster:
            return self.cluster.count()

        return super().count()

    def limit(self, limit):
        """
        Parses the number of results to return from the request. Allows range of 1-250, with a default of 10.

        Args:
            limit: limit parameter

        Returns:
            bounded limit
        """

        # Return between 1 and 250 results, defaults to 10
        return max(1, min(250, int(limit) if limit else 10))

    def weights(self, weights):
        """
        Parses the weights parameter from the request.

        Args:
            weights: weights parameter

        Returns:
            weights converted to float when set, otherwise the unmodified input
        """

        return float(weights) if weights else weights
txtai/api/cluster.py ADDED
@@ -0,0 +1,295 @@
1
+ """
2
+ Cluster module
3
+ """
4
+
5
+ import asyncio
6
+ import json
7
+ import random
8
+ import urllib.parse
9
+ import zlib
10
+
11
+ import aiohttp
12
+
13
+ from ..database.sql import Aggregate
14
+
15
+
16
class Cluster:
    """
    Aggregates multiple embeddings shards into a single logical embeddings instance.
    """

    # pylint: disable = W0231
    def __init__(self, config=None):
        """
        Creates a new Cluster.

        Args:
            config: cluster configuration, the "shards" entry is a list of shard base urls
        """

        # Configuration, an empty configuration is used when config isn't provided
        self.config = config if config is not None else {}

        # Embeddings shard urls
        self.shards = None
        if "shards" in self.config:
            self.shards = self.config["shards"]

        # Query aggregator
        self.aggregate = Aggregate()

    def search(self, query, limit=None, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input query. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            query: input query
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: dict of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} for index search, list of dict for an index + database search
        """

        # Run query and flatten results into single results list
        results = []
        for result in self.execute("get", self._searchaction(query, limit, weights, index, parameters, graph)):
            results.extend(result)

        # Combine aggregate functions and sort
        results = self.aggregate(query, results)

        # Limit results
        return results[: (limit if limit else 10)]

    @staticmethod
    def _searchaction(query, limit, weights, index, parameters, graph):
        """
        Builds the URL-encoded search action sent to each shard.

        Args:
            query: input query
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: dict (or pre-encoded JSON string) of named parameters
            graph: return graph results if True

        Returns:
            search action with an encoded query string
        """

        params = {"query": query}
        if limit:
            params["limit"] = limit
        if weights:
            params["weights"] = weights
        if index:
            params["index"] = index
        if parameters:
            params["parameters"] = json.dumps(parameters) if isinstance(parameters, dict) else parameters

        # Only send graph when enabled. Query string values are strings on the receiving end,
        # so sending graph=False would arrive as the non-empty (truthy) string "False".
        if graph:
            params["graph"] = True

        # Encode every value, parameters and index can contain reserved characters such as & and =
        return f"search?{urllib.parse.urlencode(params)}"

    def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None, graph=False):
        """
        Finds documents most similar to the input queries. This method will run either an index search
        or an index + database search depending on if a database is available.

        Args:
            queries: input queries
            limit: maximum results
            weights: hybrid score weights, if applicable
            index: index name, if applicable
            parameters: list of dicts of named parameters to bind to placeholders
            graph: return graph results if True

        Returns:
            list of {id: value, score: value} per query for index search, list of dict per query for an index + database search
        """

        # POST parameters
        params = {"queries": queries}
        if limit:
            params["limit"] = limit
        if weights:
            params["weights"] = weights
        if index:
            params["index"] = index
        if parameters:
            params["parameters"] = parameters
        if graph is not None:
            params["graph"] = graph

        # Run query
        batch = self.execute("post", "batchsearch", [params] * len(self.shards))

        # Combine results per query
        results = []
        for x, query in enumerate(queries):
            result = []
            for section in batch:
                result.extend(section[x])

            # Aggregate, sort and limit results
            results.append(self.aggregate(query, result)[: (limit if limit else 10)])

        return results

    def add(self, documents):
        """
        Adds a batch of documents for indexing.

        Args:
            documents: list of {id: value, text: value}
        """

        self.execute("post", "add", self.shard(documents))

    def index(self):
        """
        Builds an embeddings index for previously batched documents.
        """

        self.execute("get", "index")

    def upsert(self):
        """
        Runs an embeddings upsert operation for previously batched documents.
        """

        self.execute("get", "upsert")

    def delete(self, ids):
        """
        Deletes from an embeddings cluster. Returns list of ids deleted.

        Args:
            ids: list of ids to delete

        Returns:
            ids deleted
        """

        return [uid for deleted in self.execute("post", "delete", [ids] * len(self.shards)) for uid in deleted]

    def reindex(self, config, function=None):
        """
        Recreates this embeddings index using config. This method only works if document content storage is enabled.

        Args:
            config: new config
            function: optional function to prepare content for indexing
        """

        self.execute("post", "reindex", [{"config": config, "function": function}] * len(self.shards))

    def count(self):
        """
        Total number of elements in this embeddings cluster.

        Returns:
            number of elements in embeddings cluster
        """

        return sum(self.execute("get", "count"))

    def shard(self, documents):
        """
        Partitions documents across shards. Integer ids select a shard by modulo, other ids are hashed
        with adler32 first and documents without an id are assigned to a random shard.

        Args:
            documents: input documents

        Returns:
            list with one list of documents per shard
        """

        shards = [[] for _ in range(len(self.shards))]
        for document in documents:
            uid = document.get("id") if isinstance(document, dict) else document
            if uid is None:
                # Get random shard id when uid isn't set
                uid = random.randint(0, len(shards) - 1)
            elif not isinstance(uid, int):
                # Quick, stable int hash of non-integer ids (including empty strings) to derive shard id
                uid = zlib.adler32(str(uid).encode("utf-8"))

            shards[uid % len(shards)].append(document)

        return shards

    def execute(self, method, action, data=None):
        """
        Executes a HTTP action asynchronously.

        Args:
            method: get or post
            action: url action to perform
            data: post parameters

        Returns:
            json results if any
        """

        # Get urls
        urls = [f"{shard}/{action}" for shard in self.shards]
        close = False

        # Use existing loop if available, otherwise create one
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            close = True

        try:
            return loop.run_until_complete(self.run(urls, method, data))
        finally:
            # Close loop if it was created in this method
            if close:
                loop.close()

    async def run(self, urls, method, data):
        """
        Runs an async action.

        Args:
            urls: run against this list of urls
            method: get or post
            data: list of data for each url or None

        Returns:
            json results if any
        """

        async with aiohttp.ClientSession(raise_for_status=True) as session:
            tasks = []

            for x, url in enumerate(urls):
                if method == "post":
                    # Skip shards with empty post data (e.g. no documents routed to that shard)
                    if not data or data[x]:
                        tasks.append(asyncio.ensure_future(self.post(session, url, data[x] if data else None)))
                else:
                    tasks.append(asyncio.ensure_future(self.get(session, url)))

            return await asyncio.gather(*tasks)

    async def get(self, session, url):
        """
        Runs an async HTTP GET request.

        Args:
            session: ClientSession
            url: url

        Returns:
            json results if any
        """

        async with session.get(url) as resp:
            return await resp.json()

    async def post(self, session, url, data):
        """
        Runs an async HTTP POST request.

        Args:
            session: ClientSession
            url: url
            data: data to POST

        Returns:
            json results if any
        """

        async with session.post(url, json=data) as resp:
            return await resp.json()
txtai/api/extension.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ Extension module
3
+ """
4
+
5
+
6
class Extension:
    """
    Base class for API extensions. Subclasses expose custom pipelines or other custom logic
    by hooking into the FastAPI application.
    """

    def __call__(self, app):
        """
        Hook for registering custom routes and/or modifying the FastAPI instance.
        The base implementation does nothing.

        Args:
            app: FastAPI application instance
        """

        return None
txtai/api/factory.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ API factory module
3
+ """
4
+
5
+ from ..util import Resolver
6
+
7
+
8
class APIFactory:
    """
    Factory that resolves API classes and builds API instances from them.
    """

    @staticmethod
    def get(api):
        """
        Resolves the api argument into an API class using Resolver.

        Args:
            api: API instance class

        Returns:
            API
        """

        resolver = Resolver()
        return resolver(api)

    @staticmethod
    def create(config, api):
        """
        Resolves the API class and instantiates it with config.

        Args:
            config: API configuration
            api: API instance class

        Returns:
            API instance
        """

        apiclass = APIFactory.get(api)
        return apiclass(config)
@@ -0,0 +1,7 @@
1
+ """
2
+ Responses imports
3
+ """
4
+
5
+ from .factory import ResponseFactory
6
+ from .json import JSONEncoder, JSONResponse
7
+ from .messagepack import MessagePackResponse
@@ -0,0 +1,30 @@
1
+ """
2
+ Factory module
3
+ """
4
+
5
+ from .json import JSONResponse
6
+ from .messagepack import MessagePackResponse
7
+
8
+
9
class ResponseFactory:
    """
    Selects the Response class used to encode a request's response.
    """

    @staticmethod
    def create(request):
        """
        Picks a response class from the request's Accept header. MessagePack is used only when the
        header exactly equals the MessagePack media type; every other case falls back to JSON.

        Args:
            request: request

        Returns:
            response class
        """

        if request.headers.get("Accept") == MessagePackResponse.media_type:
            return MessagePackResponse

        return JSONResponse
@@ -0,0 +1,56 @@
1
+ """
2
+ JSON module
3
+ """
4
+
5
+ import base64
6
+ import json
7
+
8
+ from io import BytesIO
9
+ from typing import Any
10
+
11
+ import fastapi.responses
12
+
13
+ from PIL.Image import Image
14
+
15
+
16
class JSONEncoder(json.JSONEncoder):
    """
    Extended JSONEncoder that serializes images and byte streams as base64 encoded text.
    """

    def default(self, o):
        """
        Converts images, BytesIO and bytes instances to base64 encoded text. Other types are handled
        by the standard JSONEncoder.

        Args:
            o: object to serialize

        Returns:
            JSON serializable value
        """

        # Convert Image to BytesIO
        if isinstance(o, Image):
            buffered = BytesIO()

            # Images created in memory have no source format. Saving to a buffer requires an explicit
            # format, so fall back to PNG. quality="keep" is only valid for JPEG images.
            fmt = o.format if o.format else "PNG"
            options = {"quality": "keep"} if fmt == "JPEG" else {}
            o.save(buffered, format=fmt, **options)
            o = buffered

        # Unpack bytes from BytesIO
        if isinstance(o, BytesIO):
            o = o.getvalue()

        # Base64 encode bytes instances
        if isinstance(o, bytes):
            return base64.b64encode(o).decode("utf-8")

        # Default handler
        return super().default(o)
38
+
39
+
40
class JSONResponse(fastapi.responses.JSONResponse):
    """
    FastAPI JSONResponse that encodes with the extended JSONEncoder, so images and byte streams
    are serialized as base64 encoded text.
    """

    def render(self, content: Any) -> bytes:
        """
        Serializes content as compact UTF-8 JSON.

        Args:
            content: input content

        Returns:
            rendered content as bytes
        """

        text = json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
            cls=JSONEncoder,
        )

        return text.encode("utf-8")
@@ -0,0 +1,51 @@
1
+ """
2
+ MessagePack module
3
+ """
4
+
5
+ from io import BytesIO
6
+ from typing import Any
7
+
8
+ import msgpack
9
+
10
+ from fastapi import Response
11
+ from PIL.Image import Image
12
+
13
+
14
class MessagePackResponse(Response):
    """
    Response that encodes content with MessagePack.
    """

    media_type = "application/msgpack"

    def render(self, content: Any) -> bytes:
        """
        Serializes content with msgpack, using MessagePackEncoder as the fallback hook for types
        msgpack can't pack natively.

        Args:
            content: input content

        Returns:
            rendered content as bytes
        """

        encoder = MessagePackEncoder()
        return msgpack.packb(content, default=encoder)
33
+
34
+
35
class MessagePackEncoder:
    """
    Extended MessagePack encoder that converts images to bytes.
    """

    def __call__(self, o):
        """
        Converts images and BytesIO instances to bytes. Other objects are returned unchanged.

        Args:
            o: object msgpack could not serialize

        Returns:
            bytes for images/BytesIO, otherwise the unmodified input
        """

        # Convert Image to bytes
        if isinstance(o, Image):
            buffered = BytesIO()

            # Images created in memory have no source format. Saving to a buffer requires an explicit
            # format, so fall back to PNG. quality="keep" is only valid for JPEG images.
            fmt = o.format if o.format else "PNG"
            options = {"quality": "keep"} if fmt == "JPEG" else {}
            o.save(buffered, format=fmt, **options)
            o = buffered

        # Get bytes from BytesIO
        if isinstance(o, BytesIO):
            o = o.getvalue()

        return o
txtai/api/route.py ADDED
@@ -0,0 +1,41 @@
1
+ """
2
+ Route module
3
+ """
4
+
5
+ from fastapi.routing import APIRoute, get_request_handler
6
+
7
+ from .responses import ResponseFactory
8
+
9
+
10
class EncodingAPIRoute(APIRoute):
    """
    APIRoute that chooses the response encoding per request from the HTTP Accept header.
    """

    def get_route_handler(self):
        """
        Builds a handler that resolves the response class for each incoming request and then
        delegates to FastAPI's standard request handler.

        Returns:
            route handler function
        """

        async def handler(request):
            # Response class depends on the Accept header of this specific request
            responseclass = ResponseFactory.create(request)

            delegate = get_request_handler(
                dependant=self.dependant,
                body_field=self.body_field,
                status_code=self.status_code,
                response_class=responseclass,
                response_field=self.secure_cloned_response_field,
                response_model_include=self.response_model_include,
                response_model_exclude=self.response_model_exclude,
                response_model_by_alias=self.response_model_by_alias,
                response_model_exclude_unset=self.response_model_exclude_unset,
                response_model_exclude_defaults=self.response_model_exclude_defaults,
                response_model_exclude_none=self.response_model_exclude_none,
                dependency_overrides_provider=self.dependency_overrides_provider,
            )

            return await delegate(request)

        return handler