elasticsearch 8.12.1__py3-none-any.whl → 8.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. elasticsearch/__init__.py +7 -0
  2. elasticsearch/_async/client/__init__.py +477 -128
  3. elasticsearch/_async/client/_base.py +41 -1
  4. elasticsearch/_async/client/async_search.py +40 -12
  5. elasticsearch/_async/client/autoscaling.py +37 -11
  6. elasticsearch/_async/client/cat.py +260 -69
  7. elasticsearch/_async/client/ccr.py +123 -38
  8. elasticsearch/_async/client/cluster.py +153 -42
  9. elasticsearch/_async/client/dangling_indices.py +27 -8
  10. elasticsearch/_async/client/enrich.py +48 -14
  11. elasticsearch/_async/client/eql.py +38 -12
  12. elasticsearch/_async/client/esql.py +10 -2
  13. elasticsearch/_async/client/features.py +17 -4
  14. elasticsearch/_async/client/fleet.py +30 -7
  15. elasticsearch/_async/client/graph.py +11 -3
  16. elasticsearch/_async/client/ilm.py +101 -29
  17. elasticsearch/_async/client/indices.py +688 -181
  18. elasticsearch/_async/client/inference.py +111 -44
  19. elasticsearch/_async/client/ingest.py +59 -16
  20. elasticsearch/_async/client/license.py +58 -14
  21. elasticsearch/_async/client/logstash.py +31 -9
  22. elasticsearch/_async/client/migration.py +28 -7
  23. elasticsearch/_async/client/ml.py +781 -214
  24. elasticsearch/_async/client/monitoring.py +10 -2
  25. elasticsearch/_async/client/nodes.py +103 -29
  26. elasticsearch/_async/client/query_ruleset.py +37 -11
  27. elasticsearch/_async/client/rollup.py +79 -24
  28. elasticsearch/_async/client/search_application.py +76 -23
  29. elasticsearch/_async/client/searchable_snapshots.py +49 -12
  30. elasticsearch/_async/client/security.py +544 -143
  31. elasticsearch/_async/client/shutdown.py +28 -6
  32. elasticsearch/_async/client/slm.py +80 -22
  33. elasticsearch/_async/client/snapshot.py +140 -54
  34. elasticsearch/_async/client/sql.py +55 -15
  35. elasticsearch/_async/client/ssl.py +9 -2
  36. elasticsearch/_async/client/synonyms.py +75 -21
  37. elasticsearch/_async/client/tasks.py +29 -8
  38. elasticsearch/_async/client/text_structure.py +74 -2
  39. elasticsearch/_async/client/transform.py +106 -32
  40. elasticsearch/_async/client/watcher.py +110 -31
  41. elasticsearch/_async/client/xpack.py +16 -4
  42. elasticsearch/_async/helpers.py +1 -1
  43. elasticsearch/_otel.py +92 -0
  44. elasticsearch/_sync/client/__init__.py +477 -128
  45. elasticsearch/_sync/client/_base.py +41 -1
  46. elasticsearch/_sync/client/async_search.py +40 -12
  47. elasticsearch/_sync/client/autoscaling.py +37 -11
  48. elasticsearch/_sync/client/cat.py +260 -69
  49. elasticsearch/_sync/client/ccr.py +123 -38
  50. elasticsearch/_sync/client/cluster.py +153 -42
  51. elasticsearch/_sync/client/dangling_indices.py +27 -8
  52. elasticsearch/_sync/client/enrich.py +48 -14
  53. elasticsearch/_sync/client/eql.py +38 -12
  54. elasticsearch/_sync/client/esql.py +10 -2
  55. elasticsearch/_sync/client/features.py +17 -4
  56. elasticsearch/_sync/client/fleet.py +30 -7
  57. elasticsearch/_sync/client/graph.py +11 -3
  58. elasticsearch/_sync/client/ilm.py +101 -29
  59. elasticsearch/_sync/client/indices.py +688 -181
  60. elasticsearch/_sync/client/inference.py +111 -44
  61. elasticsearch/_sync/client/ingest.py +59 -16
  62. elasticsearch/_sync/client/license.py +58 -14
  63. elasticsearch/_sync/client/logstash.py +31 -9
  64. elasticsearch/_sync/client/migration.py +28 -7
  65. elasticsearch/_sync/client/ml.py +781 -214
  66. elasticsearch/_sync/client/monitoring.py +10 -2
  67. elasticsearch/_sync/client/nodes.py +103 -29
  68. elasticsearch/_sync/client/query_ruleset.py +37 -11
  69. elasticsearch/_sync/client/rollup.py +79 -24
  70. elasticsearch/_sync/client/search_application.py +76 -23
  71. elasticsearch/_sync/client/searchable_snapshots.py +49 -12
  72. elasticsearch/_sync/client/security.py +544 -143
  73. elasticsearch/_sync/client/shutdown.py +28 -6
  74. elasticsearch/_sync/client/slm.py +80 -22
  75. elasticsearch/_sync/client/snapshot.py +140 -54
  76. elasticsearch/_sync/client/sql.py +55 -15
  77. elasticsearch/_sync/client/ssl.py +9 -2
  78. elasticsearch/_sync/client/synonyms.py +75 -21
  79. elasticsearch/_sync/client/tasks.py +29 -8
  80. elasticsearch/_sync/client/text_structure.py +74 -2
  81. elasticsearch/_sync/client/transform.py +106 -32
  82. elasticsearch/_sync/client/watcher.py +110 -31
  83. elasticsearch/_sync/client/xpack.py +16 -4
  84. elasticsearch/_version.py +1 -1
  85. elasticsearch/helpers/actions.py +1 -1
  86. elasticsearch/helpers/vectorstore/__init__.py +62 -0
  87. elasticsearch/helpers/vectorstore/_async/__init__.py +16 -0
  88. elasticsearch/helpers/vectorstore/_async/_utils.py +39 -0
  89. elasticsearch/helpers/vectorstore/_async/embedding_service.py +89 -0
  90. elasticsearch/helpers/vectorstore/_async/strategies.py +466 -0
  91. elasticsearch/helpers/vectorstore/_async/vectorstore.py +391 -0
  92. elasticsearch/helpers/vectorstore/_sync/__init__.py +16 -0
  93. elasticsearch/helpers/vectorstore/_sync/_utils.py +39 -0
  94. elasticsearch/helpers/vectorstore/_sync/embedding_service.py +89 -0
  95. elasticsearch/helpers/vectorstore/_sync/strategies.py +466 -0
  96. elasticsearch/helpers/vectorstore/_sync/vectorstore.py +388 -0
  97. elasticsearch/helpers/vectorstore/_utils.py +116 -0
  98. elasticsearch/serializer.py +14 -0
  99. {elasticsearch-8.12.1.dist-info → elasticsearch-8.13.1.dist-info}/METADATA +28 -8
  100. elasticsearch-8.13.1.dist-info/RECORD +116 -0
  101. {elasticsearch-8.12.1.dist-info → elasticsearch-8.13.1.dist-info}/WHEEL +1 -1
  102. elasticsearch-8.12.1.dist-info/RECORD +0 -103
  103. {elasticsearch-8.12.1.dist-info → elasticsearch-8.13.1.dist-info}/LICENSE +0 -0
  104. {elasticsearch-8.12.1.dist-info → elasticsearch-8.13.1.dist-info}/NOTICE +0 -0
  105. {elasticsearch-8.12.1.dist-info → elasticsearch-8.13.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,391 @@
1
+ # Licensed to Elasticsearch B.V. under one or more contributor
2
+ # license agreements. See the NOTICE file distributed with
3
+ # this work for additional information regarding copyright
4
+ # ownership. Elasticsearch B.V. licenses this file to you under
5
+ # the Apache License, Version 2.0 (the "License"); you may
6
+ # not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+ import logging
19
+ import uuid
20
+ from typing import Any, Callable, Dict, List, Optional
21
+
22
+ from elasticsearch import AsyncElasticsearch
23
+ from elasticsearch._version import __versionstr__ as lib_version
24
+ from elasticsearch.helpers import BulkIndexError, async_bulk
25
+ from elasticsearch.helpers.vectorstore import (
26
+ AsyncEmbeddingService,
27
+ AsyncRetrievalStrategy,
28
+ )
29
+ from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
class AsyncVectorStore:
    """
    VectorStore is a higher-level abstraction of indexing and search.
    Users can pick from available retrieval strategies.

    Documents have up to 3 fields:
      - text_field: the text to be indexed and searched.
      - metadata: additional information about the document, either schema-free
        or defined by the supplied metadata_mappings.
      - vector_field (usually not filled by the user): the embedding vector of
        the text.

    Depending on the strategy, vector embeddings are
      - created by the user beforehand
      - created by this AsyncVectorStore class in Python
      - created in-stack by inference pipelines.
    """

    def __init__(
        self,
        client: AsyncElasticsearch,
        *,
        index: str,
        retrieval_strategy: AsyncRetrievalStrategy,
        embedding_service: Optional[AsyncEmbeddingService] = None,
        num_dimensions: Optional[int] = None,
        text_field: str = "text_field",
        vector_field: str = "vector_field",
        metadata_mappings: Optional[Dict[str, Any]] = None,
        user_agent: str = f"elasticsearch-py-vs/{lib_version}",
    ) -> None:
        """
        :param client: Elasticsearch client connection.
        :param index: The name of the index to query.
        :param retrieval_strategy: how to index and search the data. See the
            strategies module for available strategies. If the strategy object
            has ``text_field`` / ``vector_field`` attributes, they are
            overwritten with the values passed here.
        :param embedding_service: Optional service used to embed texts and
            queries in Python when no vectors are supplied by the caller.
        :param num_dimensions: Optional dimensionality of the embedding vectors.
            When missing and an embedding service is set, it is derived from a
            probe embedding at index-creation time.
        :param text_field: Name of the field with the textual data.
        :param vector_field: For strategies that perform embedding inference in
            Python, the embedding vector goes in this field.
        :param metadata_mappings: Optional mappings for the properties of the
            ``metadata`` field, merged into the strategy's mappings when the
            index is created.
        :param user_agent: user agent header specific to the 3rd party
            integration. Used for usage tracking in Elastic Cloud.
        """
        # Add integration-specific usage header for tracking usage in Elastic Cloud.
        # client.options preserves existing (non-user-agent) headers.
        client = client.options(headers={"User-Agent": user_agent})

        # Keep the strategy's field names in sync with the store's.
        if hasattr(retrieval_strategy, "text_field"):
            retrieval_strategy.text_field = text_field
        if hasattr(retrieval_strategy, "vector_field"):
            retrieval_strategy.vector_field = vector_field

        self.client = client
        self.index = index
        self.retrieval_strategy = retrieval_strategy
        self.embedding_service = embedding_service
        self.num_dimensions = num_dimensions
        self.text_field = text_field
        self.vector_field = vector_field
        self.metadata_mappings = metadata_mappings

    async def close(self) -> None:
        """Close the client held by this store."""
        return await self.client.close()

    @staticmethod
    def _log_bulk_error(action: str, op_type: str, error: BulkIndexError) -> None:
        """Log a bulk failure together with the reason of its first item.

        :param action: Verb used in the log message ("adding" / "deleting").
        :param op_type: Bulk action name under which the failed item is keyed.
        :param error: The bulk error being reported.
        """
        logger.error(f"Error {action} texts: {error}")
        first_item = error.errors[0] if error.errors else {}
        first_error = first_item.get(op_type, {}).get("error", {})
        # The error payload is not guaranteed to be a dict with a "reason";
        # fall back to the raw payload rather than failing inside the handler.
        if isinstance(first_error, dict):
            reason = first_error.get("reason")
        else:
            reason = first_error
        logger.error(f"First error reason: {reason}")

    async def add_texts(
        self,
        texts: List[str],
        *,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        vectors: Optional[List[List[float]]] = None,
        ids: Optional[List[str]] = None,
        refresh_indices: bool = True,
        create_index_if_not_exists: bool = True,
        bulk_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[str]:
        """Add documents to the Elasticsearch index.

        :param texts: List of text documents.
        :param metadatas: Optional list of document metadata. Must be of same
            length as texts.
        :param vectors: Optional list of embedding vectors. Must be of same
            length as texts. When omitted and an embedding service is set, the
            vectors are computed from the texts.
        :param ids: Optional list of ID strings. Must be of same length as texts.
            When omitted, random UUIDs are generated.
        :param refresh_indices: Whether to refresh the index after adding
            documents. Defaults to True.
        :param create_index_if_not_exists: Whether to create the index if it does
            not exist. Defaults to True.
        :param bulk_kwargs: Arguments to pass to the bulk function when indexing
            (for example chunk_size).

        :return: List of IDs of the created documents, either echoing the
            provided ones or returning newly created ones.
        :raises BulkIndexError: re-raised after logging when the bulk helper
            reports failed items.
        """
        bulk_kwargs = bulk_kwargs or {}
        ids = ids or [str(uuid.uuid4()) for _ in texts]

        if create_index_if_not_exists:
            await self._create_index_if_not_exists()

        if self.embedding_service and not vectors:
            vectors = await self.embedding_service.embed_documents(texts)

        requests: List[Dict[str, Any]] = []
        for i, text in enumerate(texts):
            request: Dict[str, Any] = {
                "_op_type": "index",
                "_index": self.index,
                self.text_field: text,
                "metadata": metadatas[i] if metadatas else {},
                "_id": ids[i],
            }
            if vectors:
                request[self.vector_field] = vectors[i]
            requests.append(request)

        if not requests:
            logger.debug("No texts to add to index")
            return []

        try:
            await async_bulk(
                self.client,
                requests,
                stats_only=True,
                refresh=refresh_indices,
                **bulk_kwargs,
            )
        except BulkIndexError as e:
            self._log_bulk_error("adding", "index", e)
            raise

        logger.debug(f"added texts {ids} to index")
        return ids

    async def delete(  # type: ignore[no-untyped-def]
        self,
        *,
        ids: Optional[List[str]] = None,
        query: Optional[Dict[str, Any]] = None,
        refresh_indices: bool = True,
        **delete_kwargs,
    ) -> bool:
        """Delete documents from the Elasticsearch index.

        Exactly one of ``ids`` and ``query`` must be given.

        :param ids: List of IDs of documents to delete. An empty list deletes
            nothing.
        :param query: Elasticsearch query selecting the documents to delete.
        :param refresh_indices: Whether to refresh the index after deleting
            documents. Defaults to True.
        :param delete_kwargs: Extra keyword arguments forwarded to the bulk
            helper (with ``ids``) or to ``delete_by_query`` (with ``query``).

        :return: True if deletion was successful.
        :raises ValueError: if both or neither of ``ids`` and ``query`` are given.
        :raises BulkIndexError: re-raised after logging when the bulk helper
            reports failed items.
        """
        if ids is not None and query is not None:
            raise ValueError("specify either ids or query, not both")
        elif ids is None and query is None:
            raise ValueError("either specify ids or query")

        try:
            # Branch on "is not None" so that an empty id list never falls
            # through to delete_by_query with a missing query.
            if ids is not None:
                body = [
                    {"_op_type": "delete", "_index": self.index, "_id": _id}
                    for _id in ids
                ]
                if body:
                    await async_bulk(
                        self.client,
                        body,
                        refresh=refresh_indices,
                        ignore_status=404,
                        **delete_kwargs,
                    )
                logger.debug(f"Deleted {len(body)} texts from index")
            else:
                await self.client.delete_by_query(
                    index=self.index,
                    query=query,
                    refresh=refresh_indices,
                    **delete_kwargs,
                )
        except BulkIndexError as e:
            # Only "delete" actions are sent above, so look the item up under
            # that key.
            self._log_bulk_error("deleting", "delete", e)
            raise

        return True

    async def search(
        self,
        *,
        query: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        k: int = 4,
        num_candidates: int = 50,
        fields: Optional[List[str]] = None,
        filter: Optional[List[Dict[str, Any]]] = None,
        custom_query: Optional[
            Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
        ] = None,
    ) -> List[Dict[str, Any]]:
        """
        :param query: Input query string.
        :param query_vector: Input embedding vector. If given, the query string
            is not embedded; both values are still handed to the retrieval
            strategy.
        :param k: Number of returned results.
        :param num_candidates: Number of candidates to fetch from data nodes in knn.
        :param fields: List of field names to return. ``metadata`` and the text
            field are always included. The caller's list is not modified.
        :param filter: Elasticsearch filters to apply.
        :param custom_query: Function to modify the Elasticsearch query body
            before it is sent to Elasticsearch.

        :return: List of document hits. Includes _index, _id, _score and _source.
        :raises ValueError: if an embedding service is set and neither a query
            nor a query_vector is given.
        """
        # Work on a copy so the caller's list is never mutated.
        fields = list(fields) if fields is not None else []
        if "metadata" not in fields:
            fields.append("metadata")
        if self.text_field not in fields:
            fields.append(self.text_field)

        if self.embedding_service and not query_vector:
            if not query:
                raise ValueError("specify a query or a query_vector to search")
            query_vector = await self.embedding_service.embed_query(query)

        query_body = self.retrieval_strategy.es_query(
            query=query,
            query_vector=query_vector,
            text_field=self.text_field,
            vector_field=self.vector_field,
            k=k,
            num_candidates=num_candidates,
            filter=filter or [],
        )

        if custom_query is not None:
            query_body = custom_query(query_body, query)
            logger.debug(f"Calling custom_query, Query body now: {query_body}")

        response = await self.client.search(
            index=self.index,
            **query_body,
            size=k,
            source=True,
            source_includes=fields,
        )
        hits: List[Dict[str, Any]] = response["hits"]["hits"]

        return hits

    async def _create_index_if_not_exists(self) -> None:
        """Create the index with the strategy's mappings/settings if missing.

        :raises ValueError: if the strategy needs inference but neither
            ``num_dimensions`` nor an embedding service is available, or if a
            key of ``metadata_mappings`` is already mapped by the strategy.
        """
        exists = await self.client.indices.exists(index=self.index)
        if exists.meta.status == 200:
            logger.debug(f"Index {self.index} already exists. Skipping creation.")
            return

        if self.retrieval_strategy.needs_inference():
            if not self.num_dimensions and not self.embedding_service:
                raise ValueError(
                    "retrieval strategy requires embeddings; either embedding_service "
                    "or num_dimensions need to be specified"
                )
            if not self.num_dimensions and self.embedding_service:
                # Embed a throwaway query only to learn the vector length.
                vector = await self.embedding_service.embed_query("get num dimensions")
                self.num_dimensions = len(vector)

        mappings, settings = self.retrieval_strategy.es_mappings_settings(
            text_field=self.text_field,
            vector_field=self.vector_field,
            num_dimensions=self.num_dimensions,
        )
        if self.metadata_mappings:
            metadata = mappings["properties"].get("metadata", {"properties": {}})
            existing = metadata.get("properties", {})
            # Compare against the mapped properties, not the wrapper dict that
            # only contains the "properties" key.
            for key in self.metadata_mappings:
                if key in existing:
                    raise ValueError(f"metadata key {key} already exists in mappings")

            mappings["properties"]["metadata"] = {
                "properties": {**existing, **self.metadata_mappings}
            }

        await self.retrieval_strategy.before_index_creation(
            client=self.client,
            text_field=self.text_field,
            vector_field=self.vector_field,
        )
        await self.client.indices.create(
            index=self.index, mappings=mappings, settings=settings
        )

    async def max_marginal_relevance_search(
        self,
        *,
        embedding_service: AsyncEmbeddingService,
        query: str,
        vector_field: str,
        k: int = 4,
        num_candidates: int = 20,
        lambda_mult: float = 0.5,
        fields: Optional[List[str]] = None,
        custom_query: Optional[
            Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
        ] = None,
    ) -> List[Dict[str, Any]]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        :param embedding_service: Service used to embed the query text.
        :param query: Text to look up documents similar to.
        :param vector_field: Name of the source field holding each document's
            embedding vector.
        :param k: Number of documents to return. Defaults to 4.
        :param num_candidates: Number of documents to fetch and pass to the MMR
            algorithm. Defaults to 20.
        :param lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        :param fields: Other fields to get from the Elasticsearch source. The
            caller's list is not modified.
        :param custom_query: Function to modify the Elasticsearch query body
            before it is sent to Elasticsearch.

        :return: A list of document hits selected by maximal marginal relevance.
        """
        # The vector field has to be fetched for MMR; it is stripped from the
        # results again unless the caller asked for it explicitly.
        remove_vector_query_field_from_metadata = True
        if fields is None:
            fields = [vector_field]
        elif vector_field not in fields:
            # Build a new list instead of appending to the caller's one.
            fields = [*fields, vector_field]
        else:
            remove_vector_query_field_from_metadata = False

        # Embed the query
        query_embedding = await embedding_service.embed_query(query)

        # Fetch the initial documents
        got_hits = await self.search(
            query=None,
            query_vector=query_embedding,
            k=num_candidates,
            fields=fields,
            custom_query=custom_query,
        )

        # Get the embeddings for the fetched documents
        got_embeddings = [hit["_source"][vector_field] for hit in got_hits]

        # Select documents using maximal marginal relevance
        selected_indices = maximal_marginal_relevance(
            query_embedding, got_embeddings, lambda_mult=lambda_mult, k=k
        )
        selected_hits = [got_hits[i] for i in selected_indices]

        if remove_vector_query_field_from_metadata:
            for hit in selected_hits:
                hit["_source"].pop(vector_field, None)

        return selected_hits
@@ -0,0 +1,16 @@
1
+ # Licensed to Elasticsearch B.V. under one or more contributor
2
+ # license agreements. See the NOTICE file distributed with
3
+ # this work for additional information regarding copyright
4
+ # ownership. Elasticsearch B.V. licenses this file to you under
5
+ # the Apache License, Version 2.0 (the "License"); you may
6
+ # not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
@@ -0,0 +1,39 @@
1
+ # Licensed to Elasticsearch B.V. under one or more contributor
2
+ # license agreements. See the NOTICE file distributed with
3
+ # this work for additional information regarding copyright
4
+ # ownership. Elasticsearch B.V. licenses this file to you under
5
+ # the Apache License, Version 2.0 (the "License"); you may
6
+ # not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+ from elasticsearch import BadRequestError, Elasticsearch, NotFoundError
19
+
20
+
21
def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None:
    """
    Probe the model with a single test document and return quietly on success.

    A ``BadRequestError`` from the probe is treated as success; any other
    exception raised by the client propagates to the caller.

    :raises [NotFoundError]: if the model is neither downloaded nor deployed.
    :raises [ConflictError]: if the model is downloaded but not yet deployed.
    """
    probe_text = f"test if the model '{model_id}' is deployed"
    probe_docs = [{"text_field": probe_text}]
    try:
        client.ml.infer_trained_model(model_id=model_id, docs=probe_docs)
    except BadRequestError:
        # The model is deployed but expects a different input field name.
        return
32
+
33
+
34
def model_is_deployed(client: Elasticsearch, model_id: str) -> bool:
    """Report whether the deployment probe for ``model_id`` succeeds.

    :return: False when the probe raises ``NotFoundError``, True when it
        returns normally. Other exceptions from the probe propagate.
    """
    try:
        model_must_be_deployed(client, model_id)
    except NotFoundError:
        return False
    else:
        return True
@@ -0,0 +1,89 @@
1
+ # Licensed to Elasticsearch B.V. under one or more contributor
2
+ # license agreements. See the NOTICE file distributed with
3
+ # this work for additional information regarding copyright
4
+ # ownership. Elasticsearch B.V. licenses this file to you under
5
+ # the Apache License, Version 2.0 (the "License"); you may
6
+ # not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+ from abc import ABC, abstractmethod
19
+ from typing import List
20
+
21
+ from elasticsearch import Elasticsearch
22
+ from elasticsearch._version import __versionstr__ as lib_version
23
+
24
+
25
class EmbeddingService(ABC):
    """Abstract interface for services that turn text into embedding vectors.

    Concrete subclasses must implement both ``embed_documents`` and
    ``embed_query``.
    """

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a list of documents.

        :param texts: A list of document strings to generate embeddings for.

        :return: A list of embeddings, one for each document in the input.
        """

    @abstractmethod
    def embed_query(self, query: str) -> List[float]:
        """Generate an embedding for a single query text.

        :param query: The query text to generate an embedding for.

        :return: The embedding for the input query text.
        """
43
+
44
+
45
class ElasticsearchEmbeddings(EmbeddingService):
    """Elasticsearch as a service for embedding model inference.

    You need to have an embedding model downloaded and deployed in Elasticsearch:
    - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
    - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
    """  # noqa: E501

    def __init__(
        self,
        *,
        client: Elasticsearch,
        model_id: str,
        input_field: str = "text_field",
        user_agent: str = f"elasticsearch-py-es/{lib_version}",
    ) -> None:
        """
        :param client: Elasticsearch client connection.
        :param model_id: The model_id of the model deployed in the Elasticsearch
            cluster.
        :param input_field: The name of the key for the input text field in the
            document. Defaults to 'text_field'.
        :param user_agent: user agent header specific to the 3rd party
            integration. Used for usage tracking in Elastic Cloud.
        """
        # Add integration-specific usage header for tracking usage in Elastic Cloud.
        # client.options preserves existing (non-user-agent) headers.
        client = client.options(headers={"User-Agent": user_agent})

        self.client = client
        self.model_id = model_id
        self.input_field = input_field

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in ``texts``.

        :param texts: Document strings to embed.

        :return: One embedding per input text, in input order as returned by
            the inference call. An empty input yields an empty list.
        """
        return self._embedding_func(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        :param text: The query text to embed.

        :return: The embedding of ``text``.
        """
        result = self._embedding_func([text])
        return result[0]

    def _embedding_func(self, texts: List[str]) -> List[List[float]]:
        """Run the trained model on ``texts`` and collect the predicted values.

        :param texts: Strings sent to the model, each wrapped in a document
            keyed by ``input_field``.

        :return: The ``predicted_value`` of every inference result.
        """
        # Nothing to embed: skip the inference request instead of sending an
        # empty docs list.
        if not texts:
            return []

        response = self.client.ml.infer_trained_model(
            model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
        )
        return [doc["predicted_value"] for doc in response["inference_results"]]