llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  57. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  58. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  59. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  60. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  61. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  62. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  63. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  64. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  65. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  66. llama_stack/providers/registry/agents.py +1 -0
  67. llama_stack/providers/registry/inference.py +1 -9
  68. llama_stack/providers/registry/vector_io.py +136 -16
  69. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  70. llama_stack/providers/remote/files/s3/config.py +5 -3
  71. llama_stack/providers/remote/files/s3/files.py +2 -2
  72. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  73. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  74. llama_stack/providers/remote/inference/together/together.py +4 -0
  75. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  76. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  77. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  78. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  79. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  80. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  81. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  82. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  83. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  84. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  85. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  86. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  87. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  88. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  89. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  90. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  91. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  92. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  93. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  94. llama_stack/providers/utils/bedrock/client.py +3 -3
  95. llama_stack/providers/utils/bedrock/config.py +7 -7
  96. llama_stack/providers/utils/inference/__init__.py +0 -25
  97. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  98. llama_stack/providers/utils/inference/http_client.py +239 -0
  99. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  100. llama_stack/providers/utils/inference/model_registry.py +148 -2
  101. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  102. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  103. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  104. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  105. llama_stack/providers/utils/memory/vector_store.py +46 -19
  106. llama_stack/providers/utils/responses/responses_store.py +7 -7
  107. llama_stack/providers/utils/safety.py +114 -0
  108. llama_stack/providers/utils/tools/mcp.py +44 -3
  109. llama_stack/testing/api_recorder.py +9 -3
  110. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  111. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
  112. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  113. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  114. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  115. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  116. llama_stack/models/llama/hadamard_utils.py +0 -88
  117. llama_stack/models/llama/llama3/args.py +0 -74
  118. llama_stack/models/llama/llama3/dog.jpg +0 -0
  119. llama_stack/models/llama/llama3/generation.py +0 -378
  120. llama_stack/models/llama/llama3/model.py +0 -304
  121. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  122. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  123. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  124. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  125. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  126. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  127. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  128. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  129. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  130. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  131. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  132. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  133. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  134. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  135. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  136. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  137. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  138. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  139. llama_stack/models/llama/llama4/args.py +0 -107
  140. llama_stack/models/llama/llama4/ffn.py +0 -58
  141. llama_stack/models/llama/llama4/moe.py +0 -214
  142. llama_stack/models/llama/llama4/preprocess.py +0 -435
  143. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  144. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  145. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  146. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  147. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  148. llama_stack/models/llama/quantize_impls.py +0 -316
  149. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  150. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  151. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  152. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  153. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  154. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  155. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  156. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  157. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  158. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
  159. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
llama_stack/providers/utils/inference/openai_mixin.py

@@ -10,11 +10,16 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterable
 from typing import Any
 
+import httpx
 from openai import AsyncOpenAI
 from pydantic import BaseModel, ConfigDict
 
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.http_client import (
+    _build_network_client_kwargs,
+    _merge_network_config_into_client,
+)
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_compat import (
     get_stream_options_for_telemetry,
@@ -34,6 +39,7 @@ from llama_stack_api import (
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
+    validate_embeddings_input_is_text,
 )
 
 logger = get_logger(name=__name__, category="providers::utils")
@@ -82,6 +88,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # Set to False for providers that don't support stream_options (e.g., Ollama, vLLM)
     supports_stream_options: bool = True
 
+    # Allow subclasses to control whether the provider supports tokenized embeddings input
+    # Set to True for providers that support pre-tokenized input (list[int] and list[list[int]])
+    supports_tokenized_embeddings_input: bool = False
+
     # Embedding model metadata for this provider
     # Can be set by subclasses or instances to provide embedding models
     # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@@ -121,7 +131,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Get any extra parameters to pass to the AsyncOpenAI client.
 
         Child classes can override this method to provide additional parameters
-        such as timeout settings, proxies, etc.
+        such as custom http_client, timeout settings, proxies, etc.
+
+        Note: Network configuration from config.network is automatically applied
+        in the client property. This method is for provider-specific customizations.
 
         :return: A dictionary of extra parameters
         """
@@ -194,6 +207,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         Uses the abstract methods get_api_key() and get_base_url() which must be
         implemented by child classes.
 
+        Network configuration from config.network is automatically applied.
         Users can also provide the API key via the provider data header, which
         is used instead of any config API key.
         """
@@ -205,10 +219,30 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
                 message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
             raise ValueError(message)
 
+        extra_params = self.get_extra_client_params()
+        network_kwargs = _build_network_client_kwargs(self.config.network)
+
+        # Handle http_client creation/merging:
+        # - If get_extra_client_params() provides an http_client (e.g., OCI with custom auth),
+        #   merge network config into it. The merge behavior:
+        #   * Preserves auth from get_extra_client_params() (provider-specific auth like OCI signer)
+        #   * Preserves headers from get_extra_client_params() as base
+        #   * Applies network config (TLS, proxy, timeout, headers) on top
+        #   * Network config headers take precedence over provider headers (allows override)
+        # - Otherwise, if network config exists, create http_client from it
+        # This allows providers with custom auth to still use standard network settings
+        if "http_client" in extra_params:
+            if network_kwargs:
+                extra_params["http_client"] = _merge_network_config_into_client(
+                    extra_params["http_client"], self.config.network
+                )
+        elif network_kwargs:
+            extra_params["http_client"] = httpx.AsyncClient(**network_kwargs)
+
         return AsyncOpenAI(
             api_key=api_key,
             base_url=self.get_base_url(),
-            **self.get_extra_client_params(),
+            **extra_params,
         )
 
     def _get_api_key_from_config_or_provider_data(self) -> str | None:
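For orientation, here is a minimal, hypothetical sketch of a provider that supplies its own authenticated httpx client through get_extra_client_params(); per the comments in the hunk above, the mixin is assumed to merge config.network settings (proxy, TLS, timeout, extra headers) into that client rather than replacing it. The class and auth placeholder below are illustrative and not part of llama-stack; only httpx.Auth, httpx.AsyncClient, and the AsyncOpenAI http_client parameter are real APIs.

import httpx


class _SignedAuth(httpx.Auth):
    # Placeholder for provider-specific request signing (e.g., an OCI-style signer).
    def auth_flow(self, request: httpx.Request):
        request.headers["Authorization"] = "Signature <placeholder>"
        yield request


class ExampleSignedProvider:  # in practice this would subclass OpenAIMixin
    def get_extra_client_params(self) -> dict:
        # Returning an http_client here means the mixin merges network settings
        # into this client instead of constructing a fresh httpx.AsyncClient.
        return {"http_client": httpx.AsyncClient(auth=_SignedAuth())}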
@@ -371,6 +405,8 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
             top_logprobs=params.top_logprobs,
             top_p=params.top_p,
             user=params.user,
+            safety_identifier=params.safety_identifier,
+            reasoning_effort=params.reasoning_effort,
         )
 
         if extra_body := params.model_extra:
@@ -386,6 +422,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         """
         Direct OpenAI embeddings API call.
         """
+        # Validate token array support if provider doesn't support it
+        if not self.supports_tokenized_embeddings_input:
+            validate_embeddings_input_is_text(params)
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)
 
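The implementation of validate_embeddings_input_is_text is not shown in this diff; a plausible reading, sketched below as a standalone check, is that it accepts only strings or lists of strings and rejects pre-tokenized input (a single token array or a batch of token arrays). Treat this as an assumption, not the library's actual code.

# Assumed semantics of the text-only embeddings guard (illustrative only).
def _input_is_text_only(embeddings_input) -> bool:
    if isinstance(embeddings_input, str):
        return True
    if isinstance(embeddings_input, list):
        return all(isinstance(item, str) for item in embeddings_input)
    return False


assert _input_is_text_only("hello world")
assert _input_is_text_only(["hello", "world"])
assert not _input_is_text_only([1, 2, 3])         # single token array
assert not _input_is_text_only([[1, 2], [3, 4]])  # batch of token arrays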
llama_stack/providers/utils/inference/prompt_adapter.py

@@ -4,63 +4,24 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import asyncio
 import base64
-import io
-import json
 import re
 from typing import Any
 
 import httpx
-from PIL import Image as PIL_Image
 
 from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import (
-    RawContent,
-    RawContentItem,
-    RawMediaItem,
-    RawMessage,
-    RawTextItem,
-    StopReason,
-    ToolCall,
-    ToolDefinition,
-    ToolPromptFormat,
-)
-from llama_stack.models.llama.llama3.chat_format import ChatFormat
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack_api import (
-    CompletionRequest,
     ImageContentItem,
-    InterleavedContent,
-    InterleavedContentItem,
-    OpenAIAssistantMessageParam,
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
     OpenAIFile,
-    OpenAIMessageParam,
-    OpenAISystemMessageParam,
-    OpenAIToolMessageParam,
-    OpenAIUserMessageParam,
-    ResponseFormat,
-    ResponseFormatType,
     TextContentItem,
-    ToolChoice,
 )
 
 log = get_logger(name=__name__, category="providers::utils")
 
 
-class CompletionRequestWithRawContent(CompletionRequest):
-    content: RawContent
-
-
-def decode_assistant_message(content: str, stop_reason: StopReason) -> RawMessage:
-    formatter = ChatFormat(Tokenizer.get_instance())
-    return formatter.decode_assistant_message_from_content(content, stop_reason)
-
-
 def interleaved_content_as_str(
     content: Any,
     sep: str = " ",
@@ -86,92 +47,6 @@ def interleaved_content_as_str(
     return _process(content)
 
 
-async def interleaved_content_convert_to_raw(
-    content: InterleavedContent,
-) -> RawContent:
-    """Download content from URLs / files etc. so plain bytes can be sent to the model"""
-
-    async def _localize_single(c: str | InterleavedContentItem) -> str | RawContentItem:
-        if isinstance(c, str):
-            return RawTextItem(text=c)
-        elif isinstance(c, TextContentItem):
-            return RawTextItem(text=c.text)
-        elif isinstance(c, ImageContentItem):
-            image = c.image
-            if image.url:
-                # Load image bytes from URL
-                if image.url.uri.startswith("data"):
-                    match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
-                    if not match:
-                        raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
-                    _, image_data = match.groups()
-                    data = base64.b64decode(image_data)
-                elif image.url.uri.startswith("file://"):
-                    path = image.url.uri[len("file://") :]
-                    with open(path, "rb") as f:
-                        data = f.read()  # type: ignore
-                elif image.url.uri.startswith("http"):
-                    async with httpx.AsyncClient() as client:
-                        response = await client.get(image.url.uri)
-                        data = response.content
-                else:
-                    raise ValueError("Unsupported URL type")
-            elif image.data:
-                # data is a base64 encoded string, decode it to bytes for RawMediaItem
-                data = base64.b64decode(image.data)
-            else:
-                raise ValueError("No data or URL provided")
-
-            return RawMediaItem(data=data)
-        else:
-            raise ValueError(f"Unsupported content type: {type(c)}")
-
-    if isinstance(content, list):
-        return await asyncio.gather(*(_localize_single(c) for c in content))
-    else:
-        return await _localize_single(content)
-
-
-async def convert_openai_message_to_raw_message(message: OpenAIMessageParam) -> RawMessage:
-    """Convert OpenAI message format to RawMessage format used by Llama formatters."""
-    if isinstance(message, OpenAIUserMessageParam):
-        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
-        return RawMessage(role="user", content=content)
-    elif isinstance(message, OpenAISystemMessageParam):
-        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
-        return RawMessage(role="system", content=content)
-    elif isinstance(message, OpenAIAssistantMessageParam):
-        content = await interleaved_content_convert_to_raw(message.content or "")  # type: ignore[arg-type]
-        tool_calls = []
-        if message.tool_calls:
-            for tc in message.tool_calls:
-                if tc.function:
-                    tool_calls.append(
-                        ToolCall(
-                            call_id=tc.id or "",
-                            tool_name=tc.function.name or "",
-                            arguments=tc.function.arguments or "{}",
-                        )
-                    )
-        return RawMessage(role="assistant", content=content, tool_calls=tool_calls)
-    elif isinstance(message, OpenAIToolMessageParam):
-        content = await interleaved_content_convert_to_raw(message.content)  # type: ignore[arg-type]
-        return RawMessage(role="tool", content=content)
-    else:
-        # Handle OpenAIDeveloperMessageParam if needed
-        raise ValueError(f"Unsupported message type: {type(message)}")
-
-
-def content_has_media(content: InterleavedContent):
-    def _has_media_content(c):
-        return isinstance(c, ImageContentItem)
-
-    if isinstance(content, list):
-        return any(_has_media_content(c) for c in content)
-    else:
-        return _has_media_content(content)
-
-
 async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
     if uri.startswith("http"):
         async with httpx.AsyncClient() as client:
@@ -194,87 +69,3 @@ async def localize_image_content(uri: str) -> tuple[bytes, str] | None:
         return content, fmt
     else:
         return None
-
-
-async def convert_image_content_to_url(
-    media: ImageContentItem, download: bool = False, include_format: bool = True
-) -> str:
-    image = media.image
-    if image.url and (not download or image.url.uri.startswith("data")):
-        return image.url.uri
-
-    if image.data:
-        # data is a base64 encoded string, decode it to bytes first
-        # TODO(mf): do this more efficiently, decode less
-        content = base64.b64decode(image.data)
-        pil_image = PIL_Image.open(io.BytesIO(content))
-        format = pil_image.format
-    else:
-        localize_result = await localize_image_content(image.url.uri)
-        if localize_result is None:
-            raise ValueError(f"Failed to localize image content from {image.url.uri}")
-        content, format = localize_result
-
-    if include_format:
-        return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
-    else:
-        return base64.b64encode(content).decode("utf-8")
-
-
-def augment_content_with_response_format_prompt(response_format, content):
-    if fmt_prompt := response_format_prompt(response_format):
-        if isinstance(content, list):
-            return content + [TextContentItem(text=fmt_prompt)]
-        elif isinstance(content, str):
-            return [TextContentItem(text=content), TextContentItem(text=fmt_prompt)]
-        else:
-            return [content, TextContentItem(text=fmt_prompt)]
-
-    return content
-
-
-def response_format_prompt(fmt: ResponseFormat | None):
-    if not fmt:
-        return None
-
-    if fmt.type == ResponseFormatType.json_schema.value:
-        return f"Please respond in JSON format with the schema: {json.dumps(fmt.json_schema)}"
-    elif fmt.type == ResponseFormatType.grammar.value:
-        raise NotImplementedError("Grammar response format not supported yet")
-    else:
-        raise ValueError(f"Unknown response format {fmt.type}")
-
-
-def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: list[ToolDefinition]) -> str:
-    if tool_choice == ToolChoice.auto:
-        return ""
-    elif tool_choice == ToolChoice.required:
-        return "You MUST use one of the provided functions/tools to answer the user query."
-    elif tool_choice == ToolChoice.none:
-        # tools are already not passed in
-        return ""
-    else:
-        # specific tool
-        return f"You MUST use the tool `{tool_choice}` to answer the user query."
-
-
-def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
-    llama_model = resolve_model(model)
-    if llama_model is None:
-        log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
-        return ToolPromptFormat.json
-
-    if llama_model.model_family == ModelFamily.llama3_1 or (
-        llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
-    ):
-        # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
-        return ToolPromptFormat.json
-    elif llama_model.model_family in (
-        ModelFamily.llama3_2,
-        ModelFamily.llama3_3,
-        ModelFamily.llama4,
-    ):
-        # llama3.2 and llama3.3 models follow the same tool prompt format
-        return ToolPromptFormat.python_list
-    else:
-        return ToolPromptFormat.json
llama_stack/providers/utils/memory/openai_vector_store_mixin.py

@@ -671,6 +671,19 @@ class OpenAIVectorStoreMixin(ABC):
         search_query = query
 
         try:
+            # Validate neural ranker requires model parameter
+            if ranking_options is not None:
+                if getattr(ranking_options, "ranker", None) == "neural":
+                    model_value = getattr(ranking_options, "model", None)
+                    if model_value is None or (isinstance(model_value, str) and model_value.strip() == ""):
+                        # Return empty results when model is missing for neural ranker
+                        logger.warning("model parameter is required when ranker='neural', returning empty results")
+                        return VectorStoreSearchResponsePage(
+                            search_query=query if isinstance(query, list) else [query],
+                            data=[],
+                            has_more=False,
+                            next_page=None,
+                        )
             score_threshold = (
                 ranking_options.score_threshold
                 if ranking_options and ranking_options.score_threshold is not None
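For illustration, the kind of search request that would now short-circuit with empty results versus one that proceeds; field names (ranker, model inside ranking_options) follow the hunk above, while the payload values themselves are made up:

# Neural ranking requested without a model: the guard above logs a warning and
# returns an empty VectorStoreSearchResponsePage instead of querying chunks.
search_request = {
    "query": "how do I rotate api keys?",
    "ranking_options": {"ranker": "neural"},  # no "model" -> empty results
}

# Supplying a model lets the request continue into the reranking path.
search_request_ok = {
    "query": "how do I rotate api keys?",
    "ranking_options": {"ranker": "neural", "model": "example-reranker-model"},
}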
@@ -681,7 +694,10 @@
                 "score_threshold": score_threshold,
                 "mode": search_mode,
             }
-            # TODO: Add support for ranking_options.ranker
+
+            # Use VectorStoresConfig defaults when ranking_options values are not provided
+            config = self.vector_stores_config or VectorStoresConfig()
+            params.update(self._build_reranker_params(ranking_options, config))
 
             response = await self.query_chunks(
                 vector_store_id=vector_store_id,
@@ -722,8 +738,8 @@
             )
 
         except Exception as e:
+            # Log the error and return empty results
             logger.error(f"Error searching vector store {vector_store_id}: {e}")
-            # Return empty results on error
             return VectorStoreSearchResponsePage(
                 search_query=query if isinstance(query, list) else [query],
                 data=[],
@@ -731,6 +747,62 @@
                 next_page=None,
             )
 
+    def _build_reranker_params(
+        self,
+        ranking_options: SearchRankingOptions | None,
+        config: VectorStoresConfig,
+    ) -> dict[str, Any]:
+        reranker_params: dict[str, Any] = {}
+        params: dict[str, Any] = {}
+
+        if ranking_options and ranking_options.ranker:
+            reranker_type = ranking_options.ranker
+
+            if ranking_options.ranker == "weighted":
+                alpha = ranking_options.alpha
+                if alpha is None:
+                    alpha = config.chunk_retrieval_params.weighted_search_alpha
+                reranker_params["alpha"] = alpha
+                if ranking_options.weights:
+                    reranker_params["weights"] = ranking_options.weights
+            elif ranking_options.ranker == "rrf":
+                # For RRF ranker, use impact_factor from request if provided, otherwise use VectorStoresConfig default
+                impact_factor = ranking_options.impact_factor
+                if impact_factor is None:
+                    impact_factor = config.chunk_retrieval_params.rrf_impact_factor
+                reranker_params["impact_factor"] = impact_factor
+                # If weights dict is provided (for neural combination), store it
+                if ranking_options.weights:
+                    reranker_params["weights"] = ranking_options.weights
+            elif ranking_options.ranker == "neural":
+                reranker_params["model"] = ranking_options.model
+            else:
+                logger.debug(f"Unknown ranker value: {ranking_options.ranker}, passing through")
+
+            params["reranker_type"] = reranker_type
+            params["reranker_params"] = reranker_params
+
+            # Store model and weights for neural reranking (TODO: implemented in Part II)
+            if ranking_options.model:
+                params["neural_model"] = ranking_options.model
+            if ranking_options.weights:
+                params["neural_weights"] = ranking_options.weights
+        elif ranking_options is None or ranking_options.ranker is None:
+            # No ranker specified in request - use VectorStoresConfig default
+            default_strategy = config.chunk_retrieval_params.default_reranker_strategy
+            if default_strategy in ("weighted", "rrf"):
+                params["reranker_type"] = default_strategy
+                reranker_params = {}
+
+                if default_strategy == "weighted":
+                    reranker_params["alpha"] = config.chunk_retrieval_params.weighted_search_alpha
+                elif default_strategy == "rrf":
+                    reranker_params["impact_factor"] = config.chunk_retrieval_params.rrf_impact_factor
+
+                params["reranker_params"] = reranker_params
+
+        return params
+
     def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
         """Check if metadata matches the provided filters."""
         if not filters:
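To make the resulting shape concrete, below is a condensed, stand-alone restatement of the dispatch above (weighted, rrf, and default paths only). Field names mirror SearchRankingOptions as used in the hunk; the 0.5 and 60.0 defaults are illustrative stand-ins for the VectorStoresConfig values, not the actual configured defaults.

from types import SimpleNamespace


def sketch_reranker_params(ranking_options=None) -> dict:
    # Assumed config defaults; the real values come from config.chunk_retrieval_params.
    default_alpha, default_impact = 0.5, 60.0
    if ranking_options is None or getattr(ranking_options, "ranker", None) is None:
        return {"reranker_type": "rrf", "reranker_params": {"impact_factor": default_impact}}
    if ranking_options.ranker == "weighted":
        alpha = ranking_options.alpha if ranking_options.alpha is not None else default_alpha
        return {"reranker_type": "weighted", "reranker_params": {"alpha": alpha}}
    if ranking_options.ranker == "rrf":
        k = ranking_options.impact_factor if ranking_options.impact_factor is not None else default_impact
        return {"reranker_type": "rrf", "reranker_params": {"impact_factor": k}}
    return {"reranker_type": ranking_options.ranker, "reranker_params": {}}


print(sketch_reranker_params(SimpleNamespace(ranker="weighted", alpha=0.7)))
# {'reranker_type': 'weighted', 'reranker_params': {'alpha': 0.7}}
print(sketch_reranker_params(None))
# {'reranker_type': 'rrf', 'reranker_params': {'impact_factor': 60.0}}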
@@ -738,15 +810,29 @@ class OpenAIVectorStoreMixin(ABC):
 
         filter_type = filters.get("type")
 
+        if filter_type is None:
+            if "key" not in filters and "value" not in filters and "filters" not in filters:
+                for key, value in filters.items():
+                    if key not in metadata:
+                        return False
+                    if metadata[key] != value:
+                        return False
+                return True
+            else:
+                raise ValueError("Unsupported filter structure: missing 'type' field")
+
         if filter_type in ["eq", "ne", "gt", "gte", "lt", "lte"]:
             # Comparison filter
-            key = filters.get("key")
+            filter_key = filters.get("key")
             value = filters.get("value")
 
-            if key not in metadata:
+            if filter_key is None or not isinstance(filter_key, str):
+                return False
+
+            if filter_key not in metadata:
                 return False
 
-            metadata_value = metadata[key]
+            metadata_value = metadata[filter_key]
 
             if filter_type == "eq":
                 return bool(metadata_value == value)
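Two filter shapes the updated matcher handles, shown side by side as a brief illustration (the typed form uses the "type"/"key"/"value" fields from the hunk above; the actual metadata values are made up):

# Plain key/value metadata filter (no "type" field): every pair must match exactly.
simple_filter = {"author": "alice", "year": 2024}

# Typed comparison filter: explicit operator plus key and value.
typed_filter = {"type": "eq", "key": "author", "value": "alice"}

metadata = {"author": "alice", "year": 2024, "lang": "en"}
# simple_filter -> matches (all pairs present and equal)
# typed_filter  -> matches (metadata["author"] == "alice")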
@@ -901,6 +987,7 @@
             params = OpenAIEmbeddingsRequestWithExtraBody(
                 model=embedding_model,
                 input=[interleaved_content_as_str(c.content) for c in chunks],
+                dimensions=embedding_dimension,
             )
             resp = await self.inference_api.openai_embeddings(params)
 
llama_stack/providers/utils/memory/vector_store.py

@@ -297,37 +297,64 @@ class VectorStoreWithIndex:
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)
 
-        ranker = params.get("ranker")
-        if ranker is None:
+        # Get reranker configuration from params (set by openai_vector_store_mixin)
+        # NOTE: Breaking change - removed support for old nested "ranker" format.
+        # Now uses flattened format: reranker_type and reranker_params.
+        reranker_type = params.get("reranker_type")
+        reranker_params = params.get("reranker_params", {})
+
+        # If no ranker specified, use VectorStoresConfig default
+        if reranker_type is None:
             reranker_type = (
                 RERANKER_TYPE_RRF
                 if config.chunk_retrieval_params.default_reranker_strategy == "rrf"
                 else config.chunk_retrieval_params.default_reranker_strategy
             )
             reranker_params = {"impact_factor": config.chunk_retrieval_params.rrf_impact_factor}
+
+        # Normalize reranker_type to use constants
+        if reranker_type == "weighted":
+            reranker_type = RERANKER_TYPE_WEIGHTED
+            # Ensure alpha is set (use default if not provided)
+            if "alpha" not in reranker_params:
+                reranker_params["alpha"] = config.chunk_retrieval_params.weighted_search_alpha
+        elif reranker_type == "rrf":
+            reranker_type = RERANKER_TYPE_RRF
+            # Ensure impact_factor is set (use default if not provided)
+            if "impact_factor" not in reranker_params:
+                reranker_params["impact_factor"] = config.chunk_retrieval_params.rrf_impact_factor
+        elif reranker_type == "neural":
+            # TODO: Implement neural reranking
+            log.warning(
+                "TODO: Neural reranking for vector stores is not implemented yet; "
+                "using configured reranker params without algorithm fallback."
+            )
+        elif reranker_type == "normalized":
+            reranker_type = RERANKER_TYPE_NORMALIZED
         else:
-            strategy = ranker.get("strategy", config.chunk_retrieval_params.default_reranker_strategy)
-            if strategy == "weighted":
-                weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
-                reranker_type = RERANKER_TYPE_WEIGHTED
-                reranker_params = {
-                    "alpha": weights[0] if len(weights) > 0 else config.chunk_retrieval_params.weighted_search_alpha
-                }
-            elif strategy == "normalized":
-                reranker_type = RERANKER_TYPE_NORMALIZED
-            else:
-                reranker_type = RERANKER_TYPE_RRF
-                k_value = ranker.get("params", {}).get("k", config.chunk_retrieval_params.rrf_impact_factor)
-                reranker_params = {"impact_factor": k_value}
+            # Default to RRF for unknown strategies
+            reranker_type = RERANKER_TYPE_RRF
+            if "impact_factor" not in reranker_params:
+                reranker_params["impact_factor"] = config.chunk_retrieval_params.rrf_impact_factor
+
+        # Store neural model and weights from params if provided (for future neural reranking in Part II)
+        if "neural_model" in params:
+            reranker_params["neural_model"] = params["neural_model"]
+        if "neural_weights" in params:
+            reranker_params["neural_weights"] = params["neural_weights"]
 
         query_string = interleaved_content_as_str(query)
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)
 
-        params = OpenAIEmbeddingsRequestWithExtraBody(
-            model=self.vector_store.embedding_model,
-            input=[query_string],
-        )
+        if "embedding_dimensions" in params:
+            params = OpenAIEmbeddingsRequestWithExtraBody(
+                model=self.vector_store.embedding_model,
+                input=[query_string],
+                dimensions=params.get("embedding_dimensions"),
+            )
+        else:
+            params = OpenAIEmbeddingsRequestWithExtraBody(model=self.vector_store.embedding_model, input=[query_string])
         embeddings_response = await self.inference_api.openai_embeddings(params)
         query_vector = np.array(embeddings_response.data[0].embedding, dtype=np.float32)
         if mode == "hybrid":
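A hedged before/after of the params dict accepted by query_chunks, reconstructed from the removed and added branches above; key names come from the diff itself, while the 60.0 value and the "mode" entry are only illustrative:

# Old (removed) nested format handled by the deleted else-branch:
old_params = {
    "mode": "hybrid",
    "ranker": {"strategy": "rrf", "params": {"k": 60.0}},
}

# New flattened format produced by OpenAIVectorStoreMixin._build_reranker_params:
new_params = {
    "mode": "hybrid",
    "reranker_type": "rrf",
    "reranker_params": {"impact_factor": 60.0},
}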
llama_stack/providers/utils/responses/responses_store.py

@@ -57,7 +57,7 @@ class ResponsesStore:
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)
 
         await self.sql_store.create_table(
-            "openai_responses",
+            self.reference.table_name,
             {
                 "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
                 "created_at": ColumnType.INTEGER,
@@ -112,7 +112,7 @@
             data["messages"] = [msg.model_dump() for msg in messages]
 
         await self.sql_store.upsert(
-            table="openai_responses",
+            table=self.reference.table_name,
             data={
                 "id": data["id"],
                 "created_at": data["created_at"],
@@ -137,7 +137,7 @@
             data["messages"] = [msg.model_dump() for msg in messages]
 
         await self.sql_store.insert(
-            "openai_responses",
+            self.reference.table_name,
             {
                 "id": data["id"],
                 "created_at": data["created_at"],
@@ -172,7 +172,7 @@
             where_conditions["model"] = model
 
         paginated_result = await self.sql_store.fetch_all(
-            table="openai_responses",
+            table=self.reference.table_name,
             where=where_conditions if where_conditions else None,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
@@ -195,7 +195,7 @@
             raise ValueError("Responses store is not initialized")
 
         row = await self.sql_store.fetch_one(
-            "openai_responses",
+            self.reference.table_name,
             where={"id": response_id},
         )
 
@@ -210,10 +210,10 @@
         if not self.sql_store:
            raise ValueError("Responses store is not initialized")
 
-        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
+        row = await self.sql_store.fetch_one(self.reference.table_name, where={"id": response_id})
         if not row:
             raise ValueError(f"Response with id {response_id} not found")
-        await self.sql_store.delete("openai_responses", where={"id": response_id})
+        await self.sql_store.delete(self.reference.table_name, where={"id": response_id})
         return OpenAIDeleteResponseObject(id=response_id)
 
     async def list_response_input_items(