llama-stack 0.2.21__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/apis/benchmarks/benchmarks.py +8 -0
- llama_stack/apis/scoring_functions/scoring_functions.py +8 -0
- llama_stack/cli/stack/_build.py +7 -0
- llama_stack/cli/verify_download.py +7 -10
- llama_stack/core/datatypes.py +10 -3
- llama_stack/core/library_client.py +0 -2
- llama_stack/core/routers/__init__.py +4 -1
- llama_stack/core/routers/inference.py +12 -7
- llama_stack/core/routing_tables/benchmarks.py +4 -0
- llama_stack/core/routing_tables/common.py +4 -0
- llama_stack/core/routing_tables/scoring_functions.py +4 -0
- llama_stack/distributions/ci-tests/build.yaml +1 -0
- llama_stack/distributions/ci-tests/run.yaml +7 -0
- llama_stack/distributions/starter/build.yaml +1 -0
- llama_stack/distributions/starter/run.yaml +7 -0
- llama_stack/distributions/starter/starter.py +18 -0
- llama_stack/distributions/starter-gpu/build.yaml +1 -0
- llama_stack/distributions/starter-gpu/run.yaml +7 -0
- llama_stack/distributions/watsonx/run.yaml +9 -0
- llama_stack/distributions/watsonx/watsonx.py +10 -2
- llama_stack/providers/inline/eval/meta_reference/eval.py +7 -0
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +3 -0
- llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +6 -6
- llama_stack/providers/inline/tool_runtime/rag/memory.py +101 -46
- llama_stack/providers/registry/batches.py +1 -1
- llama_stack/providers/registry/inference.py +22 -11
- llama_stack/providers/registry/scoring.py +1 -1
- llama_stack/providers/remote/eval/nvidia/eval.py +11 -2
- llama_stack/providers/remote/inference/azure/__init__.py +15 -0
- llama_stack/providers/remote/inference/azure/azure.py +64 -0
- llama_stack/providers/remote/inference/azure/config.py +63 -0
- llama_stack/providers/remote/inference/azure/models.py +28 -0
- llama_stack/providers/remote/inference/bedrock/bedrock.py +49 -2
- llama_stack/providers/remote/inference/tgi/tgi.py +43 -15
- llama_stack/providers/remote/inference/together/models.py +70 -44
- llama_stack/providers/remote/inference/together/together.py +79 -130
- llama_stack/providers/remote/inference/vertexai/vertexai.py +29 -4
- llama_stack/providers/remote/inference/vllm/vllm.py +11 -186
- llama_stack/providers/remote/inference/watsonx/config.py +2 -2
- llama_stack/providers/remote/inference/watsonx/watsonx.py +18 -2
- llama_stack/providers/utils/inference/inference_store.py +129 -19
- llama_stack/providers/utils/inference/openai_mixin.py +53 -8
- llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +14 -0
- llama_stack/providers/utils/telemetry/tracing.py +24 -10
- llama_stack/providers/utils/vector_io/vector_utils.py +2 -4
- llama_stack/testing/inference_recorder.py +43 -32
- {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/METADATA +5 -5
- {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/RECORD +52 -48
- {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/WHEEL +0 -0
- {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.2.21.dist-info → llama_stack-0.2.22.dist-info}/top_level.txt +0 -0
llama_stack/apis/benchmarks/benchmarks.py
CHANGED

@@ -93,3 +93,11 @@ class Benchmarks(Protocol):
         :param metadata: The metadata to use for the benchmark.
         """
         ...
+
+    @webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE")
+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        """Unregister a benchmark.
+
+        :param benchmark_id: The ID of the benchmark to unregister.
+        """
+        ...
llama_stack/apis/scoring_functions/scoring_functions.py
CHANGED

@@ -197,3 +197,11 @@ class ScoringFunctions(Protocol):
         :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
         """
         ...
+
+    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE")
+    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
+        """Unregister a scoring function.
+
+        :param scoring_fn_id: The ID of the scoring function to unregister.
+        """
+        ...
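Together, these two hunks expose DELETE endpoints for benchmarks and scoring functions. A minimal sketch of exercising them over HTTP, assuming a locally running stack server and the usual /v1 route prefix (both assumptions, not part of this diff); the identifiers are placeholders:

import httpx

BASE = "http://localhost:8321/v1"  # assumed server address and route prefix

# Remove a previously registered benchmark and scoring function by identifier.
httpx.delete(f"{BASE}/eval/benchmarks/my-benchmark").raise_for_status()
httpx.delete(f"{BASE}/scoring-functions/llm-as-judge::my-fn").raise_for_status()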
llama_stack/cli/stack/_build.py
CHANGED
@@ -45,6 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.exec import formulate_run_args, run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.providers.datatypes import Api
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

 DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"

@@ -294,6 +295,12 @@ def _generate_run_config(
         if build_config.external_providers_dir
         else EXTERNAL_PROVIDERS_DIR,
     )
+    if not run_config.inference_store:
+        run_config.inference_store = SqliteSqlStoreConfig(
+            **SqliteSqlStoreConfig.sample_run_config(
+                __distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
+            )
+        )
     # build providers dict
     provider_registry = get_provider_registry(build_config)
     for api in apis:
llama_stack/cli/verify_download.py
CHANGED

@@ -48,15 +48,12 @@ def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
     parser.set_defaults(func=partial(run_verify_cmd, parser=parser))


-def
-
-    # not for security purposes
-    # TODO: switch to SHA256
-    md5_hash = hashlib.md5(usedforsecurity=False)
+def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
+    sha256_hash = hashlib.sha256()
     with open(filepath, "rb") as f:
         for chunk in iter(lambda: f.read(chunk_size), b""):
-
-    return
+            sha256_hash.update(chunk)
+    return sha256_hash.hexdigest()


 def load_checksums(checklist_path: Path) -> dict[str, str]:

@@ -64,10 +61,10 @@ def load_checksums(checklist_path: Path) -> dict[str, str]:
     with open(checklist_path) as f:
         for line in f:
             if line.strip():
-
+                sha256sum, filepath = line.strip().split(" ", 1)
                 # Remove leading './' if present
                 filepath = filepath.lstrip("./")
-                checksums[filepath] =
+                checksums[filepath] = sha256sum
     return checksums


@@ -88,7 +85,7 @@ def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -
         matches = False

         if exists:
-            actual_hash =
+            actual_hash = calculate_sha256(full_path)
             matches = actual_hash == expected_hash

         results.append(
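A short usage sketch of the new helper; the import path follows the file location above and the checkpoint path is a placeholder:

from pathlib import Path

from llama_stack.cli.verify_download import calculate_sha256

# Compute the digest of a downloaded file (placeholder path); verify_files compares
# this value against the entry listed in the distribution's checksum list.
digest = calculate_sha256(Path("~/.llama/checkpoints/my-model/params.json").expanduser())
print(digest)  # 64-character hex SHA-256 digest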
llama_stack/core/datatypes.py
CHANGED
@@ -431,6 +431,12 @@ class ServerConfig(BaseModel):
     )


+class InferenceStoreConfig(BaseModel):
+    sql_store_config: SqlStoreConfig
+    max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
+    num_writers: int = Field(default=4, description="Number of concurrent background writers")
+
+
 class StackRunConfig(BaseModel):
     version: int = LLAMA_STACK_RUN_CONFIG_VERSION

@@ -464,11 +470,12 @@ Configuration for the persistence store used by the distribution registry. If no
 a default SQLite store will be used.""",
     )

-    inference_store: SqlStoreConfig | None = Field(
+    inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
         default=None,
         description="""
-Configuration for the persistence store used by the inference API.
-
+Configuration for the persistence store used by the inference API. Can be either a
+InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
+If not specified, a default SQLite store will be used.""",
     )

     # registry of "resources" in the distribution
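A hedged sketch of the two shapes the widened inference_store field now accepts, using only names visible in this diff; the SQLite call mirrors the one added to _build.py and the directory path is a placeholder:

from llama_stack.core.datatypes import InferenceStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

# Backing SQL store, built the same way _build.py now does for the default case.
sql_store = SqliteSqlStoreConfig(
    **SqliteSqlStoreConfig.sample_run_config(
        __distro_dir__="~/.llama/distributions/starter",  # placeholder directory
        db_name="inference_store.db",
    )
)

# New form: wrap the SQL store and tune the background write queue.
inference_store = InferenceStoreConfig(
    sql_store_config=sql_store,
    max_write_queue_size=10000,  # default from the new model
    num_writers=4,               # default from the new model
)

# Deprecated form, still accepted by the union type: the SqlStoreConfig directly.
legacy_inference_store = sql_store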
llama_stack/core/library_client.py
CHANGED

@@ -10,7 +10,6 @@ import json
 import logging  # allow-direct-logging
 import os
 import sys
-from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from io import BytesIO
 from pathlib import Path

@@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.async_client = AsyncLlamaStackAsLibraryClient(
             config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
-        self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.provider_data = provider_data

         self.loop = asyncio.new_event_loop()
llama_stack/core/routers/__init__.py
CHANGED

@@ -78,7 +78,10 @@ async def get_auto_router_impl(

     # TODO: move pass configs to routers instead
     if api == Api.inference and run_config.inference_store:
-        inference_store = InferenceStore(
+        inference_store = InferenceStore(
+            config=run_config.inference_store,
+            policy=policy,
+        )
         await inference_store.initialize()
         api_to_dep_impl["store"] = inference_store

llama_stack/core/routers/inference.py
CHANGED

@@ -63,7 +63,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from llama_stack.providers.utils.telemetry.tracing import get_current_span
+from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span

 logger = get_logger(name=__name__, category="core::routers")

@@ -90,6 +90,11 @@ class InferenceRouter(Inference):

     async def shutdown(self) -> None:
         logger.debug("InferenceRouter.shutdown")
+        if self.store:
+            try:
+                await self.store.shutdown()
+            except Exception as e:
+                logger.warning(f"Error during InferenceStore shutdown: {e}")

     async def register_model(
         self,

@@ -160,7 +165,7 @@ class InferenceRouter(Inference):
         metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
         if self.telemetry:
             for metric in metrics:
-
+                enqueue_event(metric)
         return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]

     async def _count_tokens(

@@ -431,7 +436,7 @@ class InferenceRouter(Inference):
             model=model_obj,
         )
         for metric in metrics:
-
+            enqueue_event(metric)

         # these metrics will show up in the client response.
         response.metrics = (

@@ -537,7 +542,7 @@ class InferenceRouter(Inference):
             model=model_obj,
         )
         for metric in metrics:
-
+            enqueue_event(metric)
         # these metrics will show up in the client response.
         response.metrics = (
             metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics

@@ -664,7 +669,7 @@ class InferenceRouter(Inference):
                     "completion_tokens",
                     "total_tokens",
                 ]:  # Only log completion and total tokens
-
+                    enqueue_event(metric)

             # Return metrics in response
             async_metrics = [

@@ -710,7 +715,7 @@ class InferenceRouter(Inference):
             )
             for metric in completion_metrics:
                 if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
-
+                    enqueue_event(metric)

             # Return metrics in response
             return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]

@@ -806,7 +811,7 @@ class InferenceRouter(Inference):
                 model=model,
             )
             for metric in metrics:
-
+                enqueue_event(metric)

             yield chunk
         finally:
llama_stack/core/routing_tables/benchmarks.py
CHANGED

@@ -56,3 +56,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
             provider_resource_id=provider_benchmark_id,
         )
         await self.register_object(benchmark)
+
+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        existing_benchmark = await self.get_benchmark(benchmark_id)
+        await self.unregister_object(existing_benchmark)
llama_stack/core/routing_tables/common.py
CHANGED

@@ -64,6 +64,10 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
         return await p.unregister_shield(obj.identifier)
     elif api == Api.datasetio:
         return await p.unregister_dataset(obj.identifier)
+    elif api == Api.eval:
+        return await p.unregister_benchmark(obj.identifier)
+    elif api == Api.scoring:
+        return await p.unregister_scoring_function(obj.identifier)
     elif api == Api.tool_runtime:
         return await p.unregister_toolgroup(obj.identifier)
     else:
llama_stack/core/routing_tables/scoring_functions.py
CHANGED

@@ -60,3 +60,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
         )
         scoring_fn.provider_id = provider_id
         await self.register_object(scoring_fn)
+
+    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
+        existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
+        await self.unregister_object(existing_scoring_fn)
llama_stack/distributions/ci-tests/run.yaml
CHANGED

@@ -81,6 +81,13 @@ providers:
     config:
       url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: ${env.AZURE_API_KEY:+azure}
+    provider_type: remote::azure
+    config:
+      api_key: ${env.AZURE_API_KEY:=}
+      api_base: ${env.AZURE_API_BASE:=}
+      api_version: ${env.AZURE_API_VERSION:=}
+      api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
llama_stack/distributions/starter/run.yaml
CHANGED

@@ -81,6 +81,13 @@ providers:
     config:
       url: https://api.sambanova.ai/v1
      api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: ${env.AZURE_API_KEY:+azure}
+    provider_type: remote::azure
+    config:
+      api_key: ${env.AZURE_API_KEY:=}
+      api_base: ${env.AZURE_API_BASE:=}
+      api_version: ${env.AZURE_API_VERSION:=}
+      api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
llama_stack/distributions/starter/starter.py
CHANGED

@@ -59,6 +59,7 @@ ENABLED_INFERENCE_PROVIDERS = [
     "cerebras",
     "nvidia",
     "bedrock",
+    "azure",
 ]

 INFERENCE_PROVIDER_IDS = {

@@ -68,6 +69,7 @@ INFERENCE_PROVIDER_IDS = {
     "cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
     "nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
     "vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}",
+    "azure": "${env.AZURE_API_KEY:+azure}",
 }


@@ -277,5 +279,21 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
             "http://localhost:11434",
             "Ollama URL",
         ),
+        "AZURE_API_KEY": (
+            "",
+            "Azure API Key",
+        ),
+        "AZURE_API_BASE": (
+            "",
+            "Azure API Base",
+        ),
+        "AZURE_API_VERSION": (
+            "",
+            "Azure API Version",
+        ),
+        "AZURE_API_TYPE": (
+            "azure",
+            "Azure API Type",
+        ),
     },
 )
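A hedged sketch of the environment variables that switch the new provider on: because the provider id is written as ${env.AZURE_API_KEY:+azure}, the entry only materializes when AZURE_API_KEY is set. The endpoint and version values below are placeholders, not defaults from this diff:

import os

# Setting AZURE_API_KEY makes ${env.AZURE_API_KEY:+azure} resolve to "azure",
# enabling the remote::azure inference provider in the generated run.yaml.
os.environ["AZURE_API_KEY"] = "<your-azure-key>"                    # placeholder
os.environ["AZURE_API_BASE"] = "https://example.openai.azure.com"   # placeholder endpoint
os.environ["AZURE_API_VERSION"] = "2024-02-01"                      # placeholder version
os.environ["AZURE_API_TYPE"] = "azure"                              # default per starter.py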
llama_stack/distributions/starter-gpu/run.yaml
CHANGED

@@ -81,6 +81,13 @@ providers:
     config:
       url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
+  - provider_id: ${env.AZURE_API_KEY:+azure}
+    provider_type: remote::azure
+    config:
+      api_key: ${env.AZURE_API_KEY:=}
+      api_base: ${env.AZURE_API_BASE:=}
+      api_version: ${env.AZURE_API_VERSION:=}
+      api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
     provider_type: inline::sentence-transformers
   vector_io:
llama_stack/distributions/watsonx/run.yaml
CHANGED

@@ -10,6 +10,7 @@ apis:
 - telemetry
 - tool_runtime
 - vector_io
+- files
 providers:
   inference:
   - provider_id: watsonx

@@ -94,6 +95,14 @@ providers:
     provider_type: inline::rag-runtime
   - provider_id: model-context-protocol
     provider_type: remote::model-context-protocol
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/watsonx/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/files_metadata.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db
llama_stack/distributions/watsonx/watsonx.py
CHANGED

@@ -9,6 +9,7 @@ from pathlib import Path
 from llama_stack.apis.models import ModelType
 from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
 from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.inline.inference.sentence_transformers import (
     SentenceTransformersInferenceConfig,
 )

@@ -16,7 +17,7 @@ from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
 from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES


-def get_distribution_template() -> DistributionTemplate:
+def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
     providers = {
         "inference": [
             BuildProvider(provider_type="remote::watsonx"),

@@ -42,6 +43,7 @@ def get_distribution_template() -> DistributionTemplate:
             BuildProvider(provider_type="inline::rag-runtime"),
             BuildProvider(provider_type="remote::model-context-protocol"),
         ],
+        "files": [BuildProvider(provider_type="inline::localfs")],
     }

     inference_provider = Provider(

@@ -79,9 +81,14 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
-        name=
+        name=name,
         distro_type="remote_hosted",
         description="Use watsonx for running LLM inference",
         container_image=None,

@@ -92,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [inference_provider, embedding_provider],
+                    "files": [files_provider],
                 },
                 default_models=default_models + [embedding_model],
                 default_tool_groups=default_tool_groups,
llama_stack/providers/inline/eval/meta_reference/eval.py
CHANGED

@@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
         )
         self.benchmarks[task_def.identifier] = task_def

+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        if benchmark_id in self.benchmarks:
+            del self.benchmarks[benchmark_id]
+
+        key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
+        await self.kvstore.delete(key)
+
     async def run_eval(
         self,
         benchmark_id: str,
llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
CHANGED

@@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
         self.llm_as_judge_fn.register_scoring_fn_def(function_def)

+    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
+        self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
+
     async def score_batch(
         self,
         dataset_id: str,
llama_stack/providers/inline/tool_runtime/rag/context_retriever.py
CHANGED

@@ -8,7 +8,7 @@
 from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.inference import
+from llama_stack.apis.inference import OpenAIUserMessageParam
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,

@@ -61,16 +61,16 @@ async def llm_rag_query_generator(
     messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-
+    rendered_content: str = template.render({"messages": messages})

     model = config.model
-    message =
-    response = await inference_api.
-
+    message = OpenAIUserMessageParam(content=rendered_content)
+    response = await inference_api.openai_chat_completion(
+        model=model,
         messages=[message],
         stream=False,
     )

-    query = response.
+    query = response.choices[0].message.content

     return query