llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
llama_stack/providers/utils/inference/http_client.py (new file)
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import ssl
+from pathlib import Path
+from typing import Any
+
+import httpx
+from openai._base_client import DefaultAsyncHttpxClient
+
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.model_registry import (
+    NetworkConfig,
+    ProxyConfig,
+    TimeoutConfig,
+    TLSConfig,
+)
+
+logger = get_logger(name=__name__, category="providers::utils")
+
+
+def _build_ssl_context(tls_config: TLSConfig) -> ssl.SSLContext | bool | Path:
+    """
+    Build an SSL context from TLS configuration.
+
+    Returns:
+        - ssl.SSLContext if advanced options (min_version, ciphers, or mTLS) are configured
+        - Path if only a CA bundle path is specified
+        - bool if only verify is specified as boolean
+    """
+    has_advanced_options = (
+        tls_config.min_version is not None or tls_config.ciphers is not None or tls_config.client_cert is not None
+    )
+
+    if not has_advanced_options:
+        return tls_config.verify
+
+    ctx = ssl.create_default_context()
+
+    if isinstance(tls_config.verify, Path):
+        ctx.load_verify_locations(str(tls_config.verify))
+    elif not tls_config.verify:
+        ctx.check_hostname = False
+        ctx.verify_mode = ssl.CERT_NONE
+
+    if tls_config.min_version:
+        if tls_config.min_version == "TLSv1.2":
+            ctx.minimum_version = ssl.TLSVersion.TLSv1_2
+        elif tls_config.min_version == "TLSv1.3":
+            ctx.minimum_version = ssl.TLSVersion.TLSv1_3
+
+    if tls_config.ciphers:
+        ctx.set_ciphers(":".join(tls_config.ciphers))
+
+    if tls_config.client_cert and tls_config.client_key:
+        ctx.load_cert_chain(certfile=str(tls_config.client_cert), keyfile=str(tls_config.client_key))
+
+    return ctx
+
+
+def _build_proxy_mounts(proxy_config: ProxyConfig) -> dict[str, httpx.AsyncHTTPTransport] | None:
+    """
+    Build httpx proxy mounts from proxy configuration.
+
+    Returns:
+        Dictionary of proxy mounts for httpx, or None if no proxies configured
+    """
+    transport_kwargs: dict[str, Any] = {}
+    if proxy_config.cacert:
+        # Convert Path to string for httpx
+        transport_kwargs["verify"] = str(proxy_config.cacert)
+
+    if proxy_config.url:
+        # Convert HttpUrl to string for httpx
+        proxy_url = str(proxy_config.url)
+        return {
+            "http://": httpx.AsyncHTTPTransport(proxy=proxy_url, **transport_kwargs),
+            "https://": httpx.AsyncHTTPTransport(proxy=proxy_url, **transport_kwargs),
+        }
+
+    mounts = {}
+    if proxy_config.http:
+        mounts["http://"] = httpx.AsyncHTTPTransport(proxy=str(proxy_config.http), **transport_kwargs)
+    if proxy_config.https:
+        mounts["https://"] = httpx.AsyncHTTPTransport(proxy=str(proxy_config.https), **transport_kwargs)
+
+    return mounts if mounts else None
+
+
+def _build_network_client_kwargs(network_config: NetworkConfig | None) -> dict[str, Any]:
+    """
+    Build httpx.AsyncClient kwargs from network configuration.
+
+    This function creates the appropriate kwargs to pass to httpx.AsyncClient
+    based on the provided NetworkConfig, without creating the client itself.
+
+    Args:
+        network_config: Network configuration including TLS, proxy, and timeout settings
+
+    Returns:
+        Dictionary of kwargs to pass to httpx.AsyncClient constructor
+    """
+    if network_config is None:
+        return {}
+
+    client_kwargs: dict[str, Any] = {}
+
+    if network_config.tls:
+        ssl_context = _build_ssl_context(network_config.tls)
+        client_kwargs["verify"] = ssl_context
+
+    if network_config.proxy:
+        mounts = _build_proxy_mounts(network_config.proxy)
+        if mounts:
+            client_kwargs["mounts"] = mounts
+
+    if network_config.timeout is not None:
+        if isinstance(network_config.timeout, TimeoutConfig):
+            # httpx.Timeout requires all four parameters (connect, read, write, pool)
+            # to be set explicitly, or a default timeout value
+            timeout_kwargs: dict[str, float | None] = {
+                "connect": network_config.timeout.connect,
+                "read": network_config.timeout.read,
+                "write": None,
+                "pool": None,
+            }
+            client_kwargs["timeout"] = httpx.Timeout(**timeout_kwargs)
+        else:
+            client_kwargs["timeout"] = httpx.Timeout(network_config.timeout)
+
+    if network_config.headers:
+        client_kwargs["headers"] = network_config.headers
+
+    return client_kwargs
+
+
+def _extract_client_config(existing_client: httpx.AsyncClient | DefaultAsyncHttpxClient) -> dict[str, Any]:
+    """
+    Extract configuration (auth, headers) from an existing http_client.
+
+    Args:
+        existing_client: Existing httpx client (may be DefaultAsyncHttpxClient)
+
+    Returns:
+        Dictionary with extracted auth and headers, if available
+    """
+    config: dict[str, Any] = {}
+
+    # Extract from DefaultAsyncHttpxClient
+    if isinstance(existing_client, DefaultAsyncHttpxClient):
+        underlying_client = existing_client._client  # type: ignore[union-attr,attr-defined]
+        if hasattr(underlying_client, "_auth"):
+            config["auth"] = underlying_client._auth  # type: ignore[attr-defined]
+        if hasattr(existing_client, "_headers"):
+            config["headers"] = existing_client._headers  # type: ignore[attr-defined]
+    else:
+        # Extract from plain httpx.AsyncClient
+        if hasattr(existing_client, "_auth"):
+            config["auth"] = existing_client._auth  # type: ignore[attr-defined]
+        if hasattr(existing_client, "_headers"):
+            config["headers"] = existing_client._headers  # type: ignore[attr-defined]
+
+    return config
+
+
+def _merge_network_config_into_client(
+    existing_client: httpx.AsyncClient | DefaultAsyncHttpxClient, network_config: NetworkConfig | None
+) -> httpx.AsyncClient | DefaultAsyncHttpxClient:
+    """
+    Merge network configuration into an existing http_client.
+
+    Extracts auth and headers from the existing client, merges with network config,
+    and creates a new client with all settings combined.
+
+    Args:
+        existing_client: Existing httpx client (may be DefaultAsyncHttpxClient)
+        network_config: Network configuration to apply
+
+    Returns:
+        New client with network config applied, or original client if merge fails
+    """
+    if network_config is None:
+        return existing_client
+
+    network_kwargs = _build_network_client_kwargs(network_config)
+    if not network_kwargs:
+        return existing_client
+
+    try:
+        # Extract existing client config (auth, headers)
+        existing_config = _extract_client_config(existing_client)
+
+        # Merge headers: existing headers first, then network config (network takes precedence)
+        if existing_config.get("headers") and network_kwargs.get("headers"):
+            merged_headers = dict(existing_config["headers"])
+            merged_headers.update(network_kwargs["headers"])
+            network_kwargs["headers"] = merged_headers
+        elif existing_config.get("headers"):
+            network_kwargs["headers"] = existing_config["headers"]
+
+        # Preserve auth from existing client
+        if existing_config.get("auth"):
+            network_kwargs["auth"] = existing_config["auth"]
+
+        # Create new client with merged config
+        new_client = httpx.AsyncClient(**network_kwargs)
+
+        # If original was DefaultAsyncHttpxClient, wrap the new client
+        if isinstance(existing_client, DefaultAsyncHttpxClient):
+            return DefaultAsyncHttpxClient(client=new_client, headers=network_kwargs.get("headers"))  # type: ignore[call-arg]
+
+        return new_client
+    except Exception as e:
+        logger.debug(f"Could not merge network config into existing http_client: {e}. Using original client.")
+        return existing_client
+
+
+def build_http_client(network_config: NetworkConfig | None) -> dict[str, Any]:
+    """
+    Build httpx.AsyncClient parameters from network configuration.
+
+    This function creates the appropriate kwargs to pass to httpx.AsyncClient
+    based on the provided NetworkConfig.
+
+    Args:
+        network_config: Network configuration including TLS, proxy, and timeout settings
+
+    Returns:
+        Dictionary of kwargs to pass to httpx.AsyncClient constructor,
+        wrapped in {"http_client": AsyncClient(...)} for use with AsyncOpenAI
+    """
+    network_kwargs = _build_network_client_kwargs(network_config)
+    if not network_kwargs:
+        return {}
+
+    return {"http_client": httpx.AsyncClient(**network_kwargs)}
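Taken together, these helpers translate a declarative `NetworkConfig` into `httpx.AsyncClient` settings. A minimal usage sketch under assumed values (the proxy URL, CA path, base URL, and API key below are placeholders, not shipped defaults; note that `TLSConfig`'s validator, shown in the `model_registry.py` hunk further down, requires the CA bundle to exist on disk):

```python
# Sketch: wiring build_http_client() into an AsyncOpenAI client.
from openai import AsyncOpenAI

from llama_stack.providers.utils.inference.http_client import build_http_client
from llama_stack.providers.utils.inference.model_registry import (
    NetworkConfig,
    ProxyConfig,
    TimeoutConfig,
    TLSConfig,
)

network = NetworkConfig(
    # TLSConfig validates that this CA bundle actually exists at load time.
    tls=TLSConfig(verify="/etc/ssl/certs/private-ca.pem", min_version="TLSv1.2"),
    proxy=ProxyConfig(url="http://proxy.internal.example:8080"),
    timeout=TimeoutConfig(connect=5.0, read=120.0),
)

# build_http_client() returns {} when there is nothing to configure, or
# {"http_client": httpx.AsyncClient(...)}, so it splats into the constructor.
client = AsyncOpenAI(
    api_key="sk-placeholder",
    base_url="https://inference.internal.example/v1",
    **build_http_client(network),
)
```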
llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -30,6 +30,7 @@ from llama_stack_api import (
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
+    validate_embeddings_input_is_text,
 )

 logger = get_logger(name=__name__, category="providers::utils")
@@ -146,6 +147,9 @@ class LiteLLMOpenAIMixin(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
+        # Validate that input contains only text, not token arrays
+        validate_embeddings_input_is_text(params)
+
         if not self.model_store:
             raise ValueError("Model store is not initialized")

@@ -270,6 +274,7 @@ class LiteLLMOpenAIMixin(
             top_logprobs=params.top_logprobs,
             top_p=params.top_p,
             user=params.user,
+            reasoning_effort=params.reasoning_effort,
             api_key=self.get_api_key(),
             api_base=self.api_base,
             **self._litellm_extra_request_params(params),
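`validate_embeddings_input_is_text` itself is defined in `llama_stack_api`, not in this diff. Judging from the call sites here and in `openai_mixin.py` below, it rejects pre-tokenized embeddings input before the request reaches the backend. A hypothetical sketch of that check, for illustration only (not the shipped implementation; the `input` field name is assumed from the OpenAI request shape):

```python
# The OpenAI embeddings API accepts str, list[str], list[int], or
# list[list[int]]; providers that cannot handle token arrays need to
# reject the integer forms up front. Hypothetical reimplementation:
def validate_embeddings_input_is_text(params) -> None:
    data = params.input  # field name assumed, per the OpenAI request shape
    if isinstance(data, list) and data and not isinstance(data[0], str):
        raise ValueError(
            "This provider only supports text embeddings input, not token arrays."
        )
```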
llama_stack/providers/utils/inference/model_registry.py
@@ -4,9 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
+from pathlib import Path
+from typing import Any, Literal

-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field, HttpUrl, SecretStr, field_validator, model_validator

 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference import (
@@ -17,6 +18,147 @@ from llama_stack_api import Model, ModelsProtocolPrivate, ModelType, Unsupported
 logger = get_logger(name=__name__, category="providers::utils")


+class TLSConfig(BaseModel):
+    """TLS/SSL configuration for secure connections."""
+
+    verify: bool | Path = Field(
+        default=True,
+        description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
+    )
+    min_version: Literal["TLSv1.2", "TLSv1.3"] | None = Field(
+        default=None,
+        description="Minimum TLS version to use. Defaults to system default if not specified.",
+    )
+    ciphers: list[str] | None = Field(
+        default=None,
+        description="List of allowed cipher suites (e.g., ['ECDHE+AESGCM', 'DHE+AESGCM']).",
+    )
+    client_cert: Path | None = Field(
+        default=None,
+        description="Path to client certificate file for mTLS authentication.",
+    )
+    client_key: Path | None = Field(
+        default=None,
+        description="Path to client private key file for mTLS authentication.",
+    )
+
+    @field_validator("verify", mode="before")
+    @classmethod
+    def validate_verify(cls, v: bool | str | Path) -> bool | Path:
+        if isinstance(v, bool):
+            return v
+        if isinstance(v, str):
+            cert_path = Path(v).expanduser().resolve()
+        else:
+            cert_path = v.expanduser().resolve()
+        if not cert_path.exists():
+            raise ValueError(f"TLS certificate file does not exist: {v}")
+        if not cert_path.is_file():
+            raise ValueError(f"TLS certificate path is not a file: {v}")
+        return cert_path
+
+    @field_validator("client_cert", "client_key", mode="before")
+    @classmethod
+    def validate_cert_paths(cls, v: str | Path | None) -> Path | None:
+        if v is None:
+            return None
+        if isinstance(v, str):
+            cert_path = Path(v).expanduser().resolve()
+        else:
+            cert_path = v.expanduser().resolve()
+        if not cert_path.exists():
+            raise ValueError(f"Certificate/key file does not exist: {v}")
+        if not cert_path.is_file():
+            raise ValueError(f"Certificate/key path is not a file: {v}")
+        return cert_path
+
+    @model_validator(mode="after")
+    def validate_mtls_pair(self) -> "TLSConfig":
+        if (self.client_cert is None) != (self.client_key is None):
+            raise ValueError("Both client_cert and client_key must be provided together for mTLS")
+        return self
+
+
+class ProxyConfig(BaseModel):
+    """Proxy configuration for HTTP connections."""
+
+    url: HttpUrl | None = Field(
+        default=None,
+        description="Single proxy URL for all connections (e.g., 'http://proxy.example.com:8080').",
+    )
+    http: HttpUrl | None = Field(
+        default=None,
+        description="Proxy URL for HTTP connections.",
+    )
+    https: HttpUrl | None = Field(
+        default=None,
+        description="Proxy URL for HTTPS connections.",
+    )
+    cacert: Path | None = Field(
+        default=None,
+        description="Path to CA certificate file for verifying the proxy's certificate. Required for proxies in interception mode.",
+    )
+    no_proxy: list[str] | None = Field(
+        default=None,
+        description="List of hosts that should bypass the proxy (e.g., ['localhost', '127.0.0.1', '.internal.corp']).",
+    )
+
+    @field_validator("cacert", mode="before")
+    @classmethod
+    def validate_cacert(cls, v: str | Path | None) -> Path | None:
+        if v is None:
+            return None
+        if isinstance(v, str):
+            cert_path = Path(v).expanduser().resolve()
+        else:
+            cert_path = v.expanduser().resolve()
+        if not cert_path.exists():
+            raise ValueError(f"Proxy CA certificate file does not exist: {v}")
+        if not cert_path.is_file():
+            raise ValueError(f"Proxy CA certificate path is not a file: {v}")
+        return cert_path
+
+    @model_validator(mode="after")
+    def validate_proxy_config(self) -> "ProxyConfig":
+        if self.url and (self.http or self.https):
+            raise ValueError("Cannot specify both 'url' and 'http'/'https' proxy settings")
+        return self
+
+
+class TimeoutConfig(BaseModel):
+    """Timeout configuration for HTTP connections."""
+
+    connect: float | None = Field(
+        default=None,
+        description="Connection timeout in seconds.",
+    )
+    read: float | None = Field(
+        default=None,
+        description="Read timeout in seconds.",
+    )
+
+
+class NetworkConfig(BaseModel):
+    """Network configuration for remote provider connections."""
+
+    tls: TLSConfig | None = Field(
+        default=None,
+        description="TLS/SSL configuration for secure connections.",
+    )
+    proxy: ProxyConfig | None = Field(
+        default=None,
+        description="Proxy configuration for HTTP connections.",
+    )
+    timeout: float | TimeoutConfig | None = Field(
+        default=None,
+        description="Timeout configuration. Can be a float (for both connect and read) or a TimeoutConfig object with separate connect and read timeouts.",
+    )
+    headers: dict[str, str] | None = Field(
+        default=None,
+        description="Additional HTTP headers to include in all requests.",
+    )
+
+
 class RemoteInferenceProviderConfig(BaseModel):
     allowed_models: list[str] | None = Field(
         default=None,
@@ -31,6 +173,10 @@ class RemoteInferenceProviderConfig(BaseModel):
         description="Authentication credential for the provider",
         alias="api_key",
     )
+    network: NetworkConfig | None = Field(
+        default=None,
+        description="Network configuration including TLS, proxy, and timeout settings.",
+    )


 # TODO: this class is more confusing than useful right now. We need to make it
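Because these are pydantic models with validators, misconfiguration fails at config-load time rather than on the first request: `TLSConfig` insists that `client_cert` and `client_key` arrive together, and `ProxyConfig` rejects mixing the single `url` with per-scheme settings. A small sketch, assuming the models exactly as defined above (URLs are placeholders):

```python
from pydantic import ValidationError

from llama_stack.providers.utils.inference.model_registry import (
    NetworkConfig,
    ProxyConfig,
)

# 'url' is mutually exclusive with the per-scheme 'http'/'https' settings:
try:
    ProxyConfig(url="http://proxy.example:8080", https="http://other.example:8080")
except ValidationError as e:
    print(e)  # includes "Cannot specify both 'url' and 'http'/'https' proxy settings"

# 'timeout' accepts either a bare float (one budget for connect and read) ...
NetworkConfig(timeout=30.0)
# ... or a TimeoutConfig with separate connect/read budgets, as shown earlier.
```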
llama_stack/providers/utils/inference/openai_compat.py
@@ -19,6 +19,7 @@ from llama_stack.models.llama.datatypes import (
     ToolCall,
     ToolDefinition,
 )
+from llama_stack_api import OpenAIFinishReason

 logger = get_logger(name=__name__, category="providers::utils")

@@ -38,7 +39,7 @@ class OpenAICompatLogprobs(BaseModel):


 class OpenAICompatCompletionChoice(BaseModel):
-    finish_reason: str | None = None
+    finish_reason: OpenAIFinishReason | None = None
     text: str | None = None
     delta: OpenAICompatCompletionChoiceDelta | None = None
     logprobs: OpenAICompatLogprobs | None = None
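Typing `finish_reason` as `OpenAIFinishReason` means an off-spec value from a backend now fails validation instead of flowing downstream. A toy illustration of the effect; the real `OpenAIFinishReason` lives in `llama_stack_api`, and the `Literal` values below are assumed from the OpenAI chat completions spec rather than copied from the source:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError

# Assumed shape of the finish-reason type, per the OpenAI spec.
FinishReason = Literal["stop", "length", "tool_calls", "content_filter"]


class Choice(BaseModel):
    finish_reason: FinishReason | None = None


Choice(finish_reason="stop")  # accepted
try:
    Choice(finish_reason="stoped")  # off-spec value from a misbehaving backend
except ValidationError:
    print("rejected at the model boundary")
```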
llama_stack/providers/utils/inference/openai_mixin.py
@@ -10,11 +10,16 @@ from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Iterable
 from typing import Any

+import httpx
 from openai import AsyncOpenAI
 from pydantic import BaseModel, ConfigDict

 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.http_client import (
+    _build_network_client_kwargs,
+    _merge_network_config_into_client,
+)
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.providers.utils.inference.openai_compat import (
     get_stream_options_for_telemetry,
@@ -34,6 +39,7 @@ from llama_stack_api import (
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
+    validate_embeddings_input_is_text,
 )

 logger = get_logger(name=__name__, category="providers::utils")
@@ -82,6 +88,10 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # Set to False for providers that don't support stream_options (e.g., Ollama, vLLM)
     supports_stream_options: bool = True

+    # Allow subclasses to control whether the provider supports tokenized embeddings input
+    # Set to True for providers that support pre-tokenized input (list[int] and list[list[int]])
+    supports_tokenized_embeddings_input: bool = False
+
     # Embedding model metadata for this provider
     # Can be set by subclasses or instances to provide embedding models
     # Format: {"model_id": {"embedding_dimension": 1536, "context_length": 8192}}
@@ -121,7 +131,10 @@
         Get any extra parameters to pass to the AsyncOpenAI client.

         Child classes can override this method to provide additional parameters
-        such as timeout settings, proxies, etc.
+        such as custom http_client, timeout settings, proxies, etc.
+
+        Note: Network configuration from config.network is automatically applied
+        in the client property. This method is for provider-specific customizations.

         :return: A dictionary of extra parameters
         """
@@ -194,6 +207,7 @@
         Uses the abstract methods get_api_key() and get_base_url() which must be
         implemented by child classes.

+        Network configuration from config.network is automatically applied.
         Users can also provide the API key via the provider data header, which
         is used instead of any config API key.
         """
@@ -205,10 +219,30 @@
             message += f' Please provide a valid API key in the provider data header, e.g. x-llamastack-provider-data: {{"{self.provider_data_api_key_field}": "<API_KEY>"}}.'
             raise ValueError(message)

+        extra_params = self.get_extra_client_params()
+        network_kwargs = _build_network_client_kwargs(self.config.network)
+
+        # Handle http_client creation/merging:
+        # - If get_extra_client_params() provides an http_client (e.g., OCI with custom auth),
+        #   merge network config into it. The merge behavior:
+        #   * Preserves auth from get_extra_client_params() (provider-specific auth like OCI signer)
+        #   * Preserves headers from get_extra_client_params() as base
+        #   * Applies network config (TLS, proxy, timeout, headers) on top
+        #   * Network config headers take precedence over provider headers (allows override)
+        # - Otherwise, if network config exists, create http_client from it
+        #   This allows providers with custom auth to still use standard network settings
+        if "http_client" in extra_params:
+            if network_kwargs:
+                extra_params["http_client"] = _merge_network_config_into_client(
+                    extra_params["http_client"], self.config.network
+                )
+        elif network_kwargs:
+            extra_params["http_client"] = httpx.AsyncClient(**network_kwargs)
+
         return AsyncOpenAI(
             api_key=api_key,
             base_url=self.get_base_url(),
-            **self.get_extra_client_params(),
+            **extra_params,
         )

     def _get_api_key_from_config_or_provider_data(self) -> str | None:
@@ -371,6 +405,7 @@
             top_logprobs=params.top_logprobs,
             top_p=params.top_p,
             user=params.user,
+            reasoning_effort=params.reasoning_effort,
         )

         if extra_body := params.model_extra:
@@ -386,6 +421,10 @@
         """
         Direct OpenAI embeddings API call.
         """
+        # Validate token array support if provider doesn't support it
+        if not self.supports_tokenized_embeddings_input:
+            validate_embeddings_input_is_text(params)
+
         provider_model_id = await self._get_provider_model_id(params.model)
         self._validate_model_allowed(provider_model_id)

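The practical upshot for provider authors: a subclass can keep supplying its own `http_client` (for signer-style auth, as the OCI comment above notes) and still inherit operator-level network settings. A runnable sketch of the merge behavior under those assumptions, using a toy `httpx.Auth` stand-in for provider auth:

```python
import asyncio

import httpx

from llama_stack.providers.utils.inference.http_client import (
    _merge_network_config_into_client,
)
from llama_stack.providers.utils.inference.model_registry import NetworkConfig


class StaticHeaderAuth(httpx.Auth):
    """Toy stand-in for provider-specific signing (e.g., an OCI request signer)."""

    def auth_flow(self, request: httpx.Request):
        request.headers["x-example-signature"] = "sig"
        yield request


async def main() -> None:
    # Provider-supplied client: custom auth plus a base header.
    provider_client = httpx.AsyncClient(
        auth=StaticHeaderAuth(), headers={"x-team": "inference"}
    )
    # Operator-supplied network settings; on header conflicts these win.
    network = NetworkConfig(timeout=30.0, headers={"x-env": "staging"})

    merged = _merge_network_config_into_client(provider_client, network)
    # The merged client keeps the auth object, carries both header sets,
    # and applies the configured timeout.
    print(merged.timeout, merged.headers["x-env"], merged.headers["x-team"])
    await merged.aclose()
    await provider_client.aclose()


asyncio.run(main())
```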