llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (155)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0

llama_stack/providers/remote/safety/sambanova/sambanova.py
@@ -4,15 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any
-
 import litellm
 import requests
 
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    OpenAIMessageParam,
+    GetShieldRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -28,7 +28,7 @@ logger = get_logger(name=__name__, category="safety::sambanova")
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 
 
-class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
+class SambaNovaSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     def __init__(self, config: SambaNovaSafetyConfig) -> None:
         self.config = config
         self.environment_available_models = []
@@ -69,17 +69,19 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
     async def unregister_shield(self, identifier: str) -> None:
         pass
 
-    async def run_shield(
-        self, shield_id: str, messages: list[OpenAIMessageParam], params: dict[str, Any] | None = None
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")
 
         shield_params = shield.params
-        logger.debug(f"run_shield::{shield_params}::messages={messages}")
+        logger.debug(f"run_shield::{shield_params}::messages={request.messages}")
 
-        response = litellm.completion(model=shield.provider_resource_id, messages=messages, api_key=self._get_api_key())
+        response = litellm.completion(
+            model=shield.provider_resource_id,
+            messages=request.messages,
+            api_key=self._get_api_key(),
+        )
         shield_message = response.choices[0].message.content
 
         if "unsafe" in shield_message.lower():

llama_stack/providers/remote/vector_io/elasticsearch/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack_api import Api, ProviderSpec
+
+from .config import ElasticsearchVectorIOConfig
+
+
+async def get_adapter_impl(config: ElasticsearchVectorIOConfig, deps: dict[Api, ProviderSpec]):
+    from .elasticsearch import ElasticsearchVectorIOAdapter
+
+    impl = ElasticsearchVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
+    await impl.initialize()
+    return impl

llama_stack/providers/remote/vector_io/elasticsearch/config.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from llama_stack.core.storage.datatypes import KVStoreReference
+from llama_stack_api import json_schema_type
+
+
+@json_schema_type
+class ElasticsearchVectorIOConfig(BaseModel):
+    elasticsearch_api_key: str | None = Field(description="The API key for the Elasticsearch instance", default=None)
+    elasticsearch_url: str | None = Field(description="The URL of the Elasticsearch instance", default="localhost:9200")
+    persistence: KVStoreReference | None = Field(
+        description="Config for KV store backend (SQLite only for now)", default=None
+    )
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {
+            "elasticsearch_url": "${env.ELASTICSEARCH_URL:=localhost:9200}",
+            "elasticsearch_api_key": "${env.ELASTICSEARCH_API_KEY:=}",
+            "persistence": KVStoreReference(
+                backend="kv_default",
+                namespace="vector_io::elasticsearch",
+            ).model_dump(exclude_none=True),
+        }
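
For context, the ${env.VAR:=default} placeholders above are resolved by the stack's config loader; with nothing set in the environment, the provider effectively receives the defaults. A minimal sketch constructing the config directly, mirroring those defaults (the persistence reference reuses the backend/namespace values shown above):

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.providers.remote.vector_io.elasticsearch.config import ElasticsearchVectorIOConfig

# Mirrors the sample_run_config defaults: local Elasticsearch, no API key.
config = ElasticsearchVectorIOConfig(
    elasticsearch_url="localhost:9200",
    elasticsearch_api_key=None,
    persistence=KVStoreReference(backend="kv_default", namespace="vector_io::elasticsearch"),
)
print(config.model_dump(exclude_none=True))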

llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py
@@ -0,0 +1,463 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from elasticsearch import ApiError, AsyncElasticsearch
+from elasticsearch.helpers import async_bulk
+from numpy.typing import NDArray
+
+from llama_stack.core.storage.kvstore import kvstore_impl
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
+from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
+from llama_stack_api import (
+    EmbeddedChunk,
+    Files,
+    Inference,
+    InterleavedContent,
+    QueryChunksResponse,
+    VectorIO,
+    VectorStore,
+    VectorStoreNotFoundError,
+    VectorStoresProtocolPrivate,
+)
+
+from .config import ElasticsearchVectorIOConfig
+
+log = get_logger(name=__name__, category="vector_io::elasticsearch")
+
+# KV store prefixes for vector databases
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_stores:elasticsearch:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:elasticsearch:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:elasticsearch:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:elasticsearch:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:elasticsearch:{VERSION}::"
+
+
+class ElasticsearchIndex(EmbeddingIndex):
+    def __init__(self, client: AsyncElasticsearch, collection_name: str):
+        self.client = client
+        self.collection_name = collection_name
+
+    # Check if the rerank_params contains the following structure:
+    # {
+    #     "retrievers": {
+    #         "standard": {"weight": 0.7},
+    #         "knn": {"weight": 0.3}
+    #     }
+    # }
+    async def _is_rerank_linear_param_valid(self, value: dict) -> bool:
+        """Validate linear reranker parameters structure."""
+        try:
+            retrievers = value.get("retrievers", {})
+            return (
+                isinstance(retrievers.get("standard"), dict)
+                and isinstance(retrievers.get("knn"), dict)
+                and "weight" in retrievers["standard"]
+                and "weight" in retrievers["knn"]
+            )
+        except (AttributeError, TypeError):
+            return False
+
+    def _convert_to_linear_params(self, reranker_params: dict[str, Any]) -> dict[str, Any] | None:
+        weights = reranker_params.get("weights")
+        alpha = reranker_params.get("alpha")
+        if weights is not None:
+            vector_weight = weights.get("vector")
+            keyword_weight = weights.get("keyword")
+            if vector_weight is None or keyword_weight is None:
+                log.warning("Elasticsearch linear retriever requires 'vector' and 'keyword' weights; ignoring weights.")
+                return None
+            total = vector_weight + keyword_weight
+            if total == 0:
+                log.warning(
+                    "Elasticsearch linear retriever weights for 'vector' and 'keyword' sum to 0; ignoring weights."
+                )
+                return None
+            if abs(total - 1.0) > 0.001:
+                log.warning(
+                    "Elasticsearch linear retriever uses normalized vector/keyword weights; "
+                    "renormalizing provided weights."
+                )
+                vector_weight /= total
+                keyword_weight /= total
+        elif alpha is not None:
+            vector_weight = alpha
+            keyword_weight = 1 - alpha
+        else:
+            return None
+
+        return {
+            "retrievers": {
+                "standard": {"weight": keyword_weight},
+                "knn": {"weight": vector_weight},
+            }
+        }
+
+    async def initialize(self) -> None:
+        # Elasticsearch collections (indexes) are created on-demand in add_chunks
+        # If the index does not exist, it will be created in add_chunks.
+        pass
+
+    async def add_chunks(self, chunks: list[EmbeddedChunk]):
+        """Adds chunks to the Elasticsearch index."""
+        if not chunks:
+            return
+
+        try:
+            await self.client.indices.create(
+                index=self.collection_name,
+                body={
+                    "mappings": {
+                        "properties": {
+                            "content": {"type": "text"},
+                            "chunk_id": {"type": "keyword"},
+                            "metadata": {"type": "object"},
+                            "chunk_metadata": {"type": "object"},
+                            "embedding": {"type": "dense_vector", "dims": len(chunks[0].embedding)},
+                            "embedding_dimension": {"type": "integer"},
+                            "embedding_model": {"type": "keyword"},
+                        }
+                    }
+                },
+            )
+        except ApiError as e:
+            if e.status_code != 400 or "resource_already_exists_exception" not in e.message:
+                log.error(f"Error creating Elasticsearch index {self.collection_name}: {e}")
+                raise
+
+        actions = []
+        for chunk in chunks:
+            actions.append(
+                {
+                    "_op_type": "index",
+                    "_index": self.collection_name,
+                    "_id": chunk.chunk_id,
+                    "_source": chunk.model_dump(
+                        exclude_none=True,
+                        include={
+                            "content",
+                            "chunk_id",
+                            "metadata",
+                            "chunk_metadata",
+                            "embedding",
+                            "embedding_dimension",
+                            "embedding_model",
+                        },
+                    ),
+                }
+            )
+
+        try:
+            successful_count, error_count = await async_bulk(
+                client=self.client, actions=actions, timeout="300s", refresh=True, raise_on_error=False, stats_only=True
+            )
+            if error_count > 0:
+                log.warning(
+                    f"{error_count} out of {len(chunks)} documents failed to upload in Elasticsearch index {self.collection_name}"
+                )
+
+            log.info(f"Successfully added {successful_count} chunks to Elasticsearch index {self.collection_name}")
+        except Exception as e:
+            log.error(f"Error adding chunks to Elasticsearch index {self.collection_name}: {e}")
+            raise
+
+    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
+        """Remove a chunk from the Elasticsearch index."""
+
+        actions = []
+        for chunk in chunks_for_deletion:
+            actions.append({"_op_type": "delete", "_index": self.collection_name, "_id": chunk.chunk_id})
+
+        try:
+            successful_count, error_count = await async_bulk(
+                client=self.client, actions=actions, timeout="300s", refresh=True, raise_on_error=True, stats_only=True
+            )
+            if error_count > 0:
+                log.warning(
+                    f"{error_count} out of {len(chunks_for_deletion)} documents failed to be deleted in Elasticsearch index {self.collection_name}"
+                )
+
+            log.info(f"Successfully deleted {successful_count} chunks from Elasticsearch index {self.collection_name}")
+        except Exception as e:
+            log.error(f"Error deleting chunks from Elasticsearch index {self.collection_name}: {e}")
+            raise
+
+    async def _results_to_chunks(self, results: dict) -> QueryChunksResponse:
+        """Convert search results to QueryChunksResponse."""
+
+        chunks, scores = [], []
+        for result in results.get("hits", {}).get("hits", []):
+            try:
+                source = result.get("_source", {})
+                chunk = EmbeddedChunk(
+                    content=source.get("content"),
+                    chunk_id=result.get("_id"),
+                    embedding=source.get("embedding", []),
+                    embedding_dimension=source.get("embedding_dimension", len(source.get("embedding", []))),
+                    embedding_model=source.get("embedding_model", "unknown"),
+                    chunk_metadata=source.get("chunk_metadata", {}),
+                    metadata=source.get("metadata", {}),
+                )
+            except Exception:
+                log.exception("Failed to parse chunk")
+                continue
+
+            chunks.append(chunk)
+            scores.append(result.get("_score"))
+
+        return QueryChunksResponse(chunks=chunks, scores=scores)
+
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+        """Vector search using kNN."""
+
+        try:
+            results = await self.client.search(
+                index=self.collection_name,
+                query={"knn": {"field": "embedding", "query_vector": embedding.tolist(), "k": k}},
+                min_score=score_threshold,
+                size=k,
+                source={"exclude_vectors": False},  # Retrieve the embedding
+                ignore_unavailable=True,  # In case the index does not exist
+            )
+        except Exception as e:
+            log.error(f"Error performing vector query on Elasticsearch index {self.collection_name}: {e}")
+            raise
+
+        return await self._results_to_chunks(results)
+
+    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
+        """Keyword search using match query."""
+
+        try:
+            results = await self.client.search(
+                index=self.collection_name,
+                query={"match": {"content": {"query": query_string}}},
+                min_score=score_threshold,
+                size=k,
+                source={"exclude_vectors": False},  # Retrieve the embedding
+                ignore_unavailable=True,  # In case the index does not exist
+            )
+        except Exception as e:
+            log.error(f"Error performing keyword query on Elasticsearch index {self.collection_name}: {e}")
+            raise
+
+        return await self._results_to_chunks(results)
+
+    async def query_hybrid(
+        self,
+        embedding: NDArray,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+        reranker_type: str,
+        reranker_params: dict[str, Any] | None = None,
+    ) -> QueryChunksResponse:
+        supported_retrievers = ["rrf", "linear"]
+        original_reranker_type = reranker_type
+        if reranker_type == "weighted":
+            log.warning("Elasticsearch does not support 'weighted' reranker; using 'linear' retriever instead.")
+            reranker_type = "linear"
+        if reranker_type not in supported_retrievers:
+            log.warning(
+                f"Unsupported reranker type: {reranker_type}. Supported types are: {supported_retrievers}. "
+                "Falling back to 'rrf'."
+            )
+            reranker_type = "rrf"
+
+        retriever = {
+            reranker_type: {
+                "retrievers": [
+                    {"retriever": {"standard": {"query": {"match": {"content": query_string}}}}},
+                    {
+                        "retriever": {
+                            "knn": {
+                                "field": "embedding",
+                                "query_vector": embedding.tolist(),
+                                "k": k,
+                                "num_candidates": k,
+                            }
+                        }
+                    },
+                ]
+            }
+        }
+        # Elasticsearch requires rank_window_size >= size for rrf/linear retrievers.
+        retriever[reranker_type]["rank_window_size"] = k
+
+        # Add reranker parameters if provided for RRF (e.g. rank_constant, rank_window_size, filter)
+        # see https://www.elastic.co/docs/reference/elasticsearch/rest-apis/retrievers/rrf-retriever
+        if reranker_type == "rrf" and reranker_params is not None:
+            allowed_rrf_params = {"rank_constant", "rank_windows_size", "filter"}
+            rrf_params = dict(reranker_params)
+            if "impact_factor" in rrf_params:
+                if "rank_constant" not in rrf_params:
+                    rrf_params["rank_constant"] = rrf_params.pop("impact_factor")
+                    log.warning("Elasticsearch RRF does not support impact_factor; mapping to rank_constant.")
+                else:
+                    rrf_params.pop("impact_factor")
+                    log.warning("Elasticsearch RRF ignores impact_factor when rank_constant is provided.")
+            if "rank_window_size" not in rrf_params and "rank_windows_size" in rrf_params:
+                rrf_params["rank_window_size"] = rrf_params.pop("rank_windows_size")
+            extra_keys = set(rrf_params.keys()) - allowed_rrf_params
+            if extra_keys:
+                log.warning(f"Ignoring unsupported RRF parameters for Elasticsearch: {extra_keys}")
+                for key in extra_keys:
+                    rrf_params.pop(key, None)
+            if rrf_params:
+                retriever["rrf"].update(rrf_params)
+        elif reranker_type == "linear" and reranker_params is not None:
+            # Add reranker parameters (i.e. weights) for linear
+            # see https://www.elastic.co/docs/reference/elasticsearch/rest-apis/retrievers/linear-retriever
+            if await self._is_rerank_linear_param_valid(reranker_params) is False:
+                converted_params = self._convert_to_linear_params(reranker_params)
+                if converted_params is None:
+                    log.warning(
+                        "Invalid linear reranker parameters for Elasticsearch; "
+                        'expected {"retrievers": {"standard": {"weight": float}, "knn": {"weight": float}}}. '
+                        "Ignoring provided parameters."
+                    )
+                else:
+                    reranker_params = converted_params
+            try:
+                if await self._is_rerank_linear_param_valid(reranker_params):
+                    retriever["linear"]["retrievers"][0].update(reranker_params["retrievers"]["standard"])
+                    retriever["linear"]["retrievers"][1].update(reranker_params["retrievers"]["knn"])
+            except Exception as e:
+                log.error(f"Error updating linear retrievers parameters: {e}")
+                raise
+        elif reranker_type == "linear" and reranker_params is None and original_reranker_type == "weighted":
+            converted_params = self._convert_to_linear_params({})
+            if converted_params:
+                retriever["linear"]["retrievers"][0].update(converted_params["retrievers"]["standard"])
+                retriever["linear"]["retrievers"][1].update(converted_params["retrievers"]["knn"])
+        try:
+            results = await self.client.search(
+                index=self.collection_name,
+                size=k,
+                retriever=retriever,
+                min_score=score_threshold,
+                source={"exclude_vectors": False},  # Retrieve the embedding
+                ignore_unavailable=True,  # In case the index does not exist
+            )
+        except Exception as e:
+            log.error(f"Error performing hybrid query on Elasticsearch index {self.collection_name}: {e}")
+            raise
+
+        return await self._results_to_chunks(results)
+
+    async def delete(self):
+        """Delete the entire Elasticsearch index with collection_name."""
+
+        try:
+            await self.client.indices.delete(index=self.collection_name, ignore_unavailable=True)
+        except Exception as e:
+            log.error(f"Error deleting Elasticsearch index {self.collection_name}: {e}")
+            raise
+
+
+class ElasticsearchVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
+    def __init__(
+        self,
+        config: ElasticsearchVectorIOConfig,
+        inference_api: Inference,
+        files_api: Files | None = None,
+    ) -> None:
+        super().__init__(inference_api=inference_api, files_api=files_api, kvstore=None)
+        self.config = config
+        self.client: AsyncElasticsearch = None
+        self.cache = {}
+        self.vector_store_table = None
+        self.metadata_collection_name = "openai_vector_stores_metadata"
+
+    async def initialize(self) -> None:
+        self.client = AsyncElasticsearch(hosts=self.config.elasticsearch_url, api_key=self.config.elasticsearch_api_key)
+        self.kvstore = await kvstore_impl(self.config.persistence)
+
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored_vector_stores = await self.kvstore.values_in_range(start_key, end_key)
+
+        for vector_store_data in stored_vector_stores:
+            vector_store = VectorStore.model_validate_json(vector_store_data)
+            index = VectorStoreWithIndex(
+                vector_store, ElasticsearchIndex(self.client, vector_store.identifier), self.inference_api
+            )
+            self.cache[vector_store.identifier] = index
+        self.openai_vector_stores = await self._load_openai_vector_stores()
+
+    async def shutdown(self) -> None:
+        await self.client.close()
+        # Clean up mixin resources (file batch tasks)
+        await super().shutdown()
+
+    async def register_vector_store(self, vector_store: VectorStore) -> None:
+        assert self.kvstore is not None
+        key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
+        await self.kvstore.set(key=key, value=vector_store.model_dump_json())
+
+        index = VectorStoreWithIndex(
+            vector_store=vector_store,
+            index=ElasticsearchIndex(self.client, vector_store.identifier),
+            inference_api=self.inference_api,
+        )
+
+        self.cache[vector_store.identifier] = index
+
+    async def unregister_vector_store(self, vector_store_id: str) -> None:
+        if vector_store_id in self.cache:
+            await self.cache[vector_store_id].index.delete()
+            del self.cache[vector_store_id]
+
+        assert self.kvstore is not None
+        await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}")
+
+    async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> VectorStoreWithIndex | None:
+        if vector_store_id in self.cache:
+            return self.cache[vector_store_id]
+
+        if self.vector_store_table is None:
+            raise ValueError(f"Vector DB not found {vector_store_id}")
+
+        vector_store = await self.vector_store_table.get_vector_store(vector_store_id)
+        if not vector_store:
+            raise VectorStoreNotFoundError(vector_store_id)
+
+        index = VectorStoreWithIndex(
+            vector_store=vector_store,
+            index=ElasticsearchIndex(client=self.client, collection_name=vector_store.identifier),
+            inference_api=self.inference_api,
+        )
+        self.cache[vector_store_id] = index
+        return index
+
+    async def insert_chunks(
+        self, vector_store_id: str, chunks: list[EmbeddedChunk], ttl_seconds: int | None = None
+    ) -> None:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
+        if not index:
+            raise VectorStoreNotFoundError(vector_store_id)
+
+        await index.insert_chunks(chunks)
+
+    async def query_chunks(
+        self, vector_store_id: str, query: InterleavedContent, params: dict[str, Any] | None = None
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_store_index(vector_store_id)
+        if not index:
+            raise VectorStoreNotFoundError(vector_store_id)
+
+        return await index.query_chunks(query, params)
+
+    async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
+        """Delete chunks from an Elasticsearch vector store."""
+        index = await self._get_and_cache_vector_store_index(store_id)
+        if not index:
+            raise ValueError(f"Vector DB {store_id} not found")
+
+        await index.index.delete_chunks(chunks_for_deletion)
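
To make the hybrid-search parameter mapping concrete, here is a standalone sketch of the weights/alpha conversion that _convert_to_linear_params performs above (logging omitted; the function name is mine):

def to_linear_retriever_params(reranker_params: dict) -> dict | None:
    """Map {'weights': {'vector': v, 'keyword': w}} or {'alpha': a} onto
    Elasticsearch linear-retriever weights, renormalizing to sum to 1."""
    weights = reranker_params.get("weights")
    alpha = reranker_params.get("alpha")
    if weights is not None:
        vector_w = weights.get("vector")
        keyword_w = weights.get("keyword")
        if vector_w is None or keyword_w is None:
            return None
        total = vector_w + keyword_w
        if total == 0:
            return None
        vector_w, keyword_w = vector_w / total, keyword_w / total
    elif alpha is not None:
        vector_w, keyword_w = alpha, 1 - alpha
    else:
        return None
    return {"retrievers": {"standard": {"weight": keyword_w}, "knn": {"weight": vector_w}}}

# alpha=0.75 splits 75/25 between the kNN and keyword retrievers.
print(to_linear_retriever_params({"alpha": 0.75}))
# {'retrievers': {'standard': {'weight': 0.25}, 'knn': {'weight': 0.75}}}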

llama_stack/providers/remote/vector_io/oci/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.remote.vector_io.oci.config import OCI26aiVectorIOConfig
+from llama_stack_api import Api, ProviderSpec
+
+
+async def get_adapter_impl(config: OCI26aiVectorIOConfig, deps: dict[Api, ProviderSpec]):
+    from typing import cast
+
+    from llama_stack.providers.remote.vector_io.oci.oci26ai import OCI26aiVectorIOAdapter
+    from llama_stack_api import Files, Inference
+
+    assert isinstance(config, OCI26aiVectorIOConfig), f"Unexpected config type: {type(config)}"
+    inference_api = cast(Inference, deps[Api.inference])
+    files_api = cast(Files | None, deps.get(Api.files))
+    impl = OCI26aiVectorIOAdapter(config, inference_api, files_api)
+    await impl.initialize()
+    return impl

llama_stack/providers/remote/vector_io/oci/config.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from llama_stack.core.storage.datatypes import KVStoreReference
+from llama_stack_api import json_schema_type
+
+
+@json_schema_type
+class OCI26aiVectorIOConfig(BaseModel):
+    conn_str: str = Field(description="Connection string for the given 26ai Service")
+    user: str = Field(description="Username to connect to the service")
+    password: str = Field(description="Password to connect to the service")
+    tnsnames_loc: str = Field(description="Directory location of the tnsnames.ora file")
+    ewallet_pem_loc: str = Field(description="Directory location of the ewallet.pem file")
+    ewallet_password: str = Field(description="Password for the ewallet.pem file")
+    persistence: KVStoreReference = Field(description="Config for KV store backend")
+    consistency_level: str = Field(description="The consistency level of the OCI26ai server", default="Strong")
+    vector_datatype: str = Field(description="Vector datatype for embeddings", default="FLOAT32")
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {
+            "conn_str": "${env.OCI26AI_CONNECTION_STRING}",
+            "user": "${env.OCI26AI_USER}",
+            "password": "${env.OCI26AI_PASSWORD}",
+            "tnsnames_loc": "${env.OCI26AI_TNSNAMES_LOC}",
+            "ewallet_pem_loc": "${env.OCI26AI_EWALLET_PEM_LOC}",
+            "ewallet_password": "${env.OCI26AI_EWALLET_PWD}",
+            "vector_datatype": "${env.OCI26AI_VECTOR_DATATYPE:=FLOAT32}",
+            "persistence": KVStoreReference(
+                backend="kv_default",
+                namespace="vector_io::oci26ai",
+            ).model_dump(exclude_none=True),
+        }
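
A construction sketch for the OCI config, assuming wallet files under /etc/oci/wallet; paths, aliases, and credentials here are placeholders, since deployments normally supply them through the OCI26AI_* environment variables referenced in sample_run_config:

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.providers.remote.vector_io.oci.config import OCI26aiVectorIOConfig

config = OCI26aiVectorIOConfig(
    conn_str="mydb_high",               # placeholder TNS alias
    user="vector_app",                  # placeholder username
    password="change-me",               # placeholder password
    tnsnames_loc="/etc/oci/wallet",     # directory holding tnsnames.ora
    ewallet_pem_loc="/etc/oci/wallet",  # directory holding ewallet.pem
    ewallet_password="change-me",       # placeholder wallet password
    persistence=KVStoreReference(backend="kv_default", namespace="vector_io::oci26ai"),
)
# consistency_level and vector_datatype fall back to "Strong" and "FLOAT32".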