llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
- llama_stack-0.5.0.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
llama_stack/core/routers/eval_scoring.py:

@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
 from typing import Any
 
 from llama_stack.log import get_logger
@@ -11,12 +10,23 @@ from llama_stack_api import (
     BenchmarkConfig,
     Eval,
     EvaluateResponse,
+    EvaluateRowsRequest,
     Job,
+    JobCancelRequest,
+    JobResultRequest,
+    JobStatusRequest,
     RoutingTable,
+    RunEvalRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
-    ScoringFnParams,
+    resolve_evaluate_rows_request,
+    resolve_job_cancel_request,
+    resolve_job_result_request,
+    resolve_job_status_request,
+    resolve_run_eval_request,
 )
 
 logger = get_logger(name=__name__, category="core::routers")
@@ -40,21 +50,22 @@ class ScoringRouter(Scoring):
 
     async def score_batch(
         self,
-        dataset_id: str,
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
-        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
+        logger.debug(f"ScoringRouter.score_batch: {request.dataset_id}")
         res = {}
-        for fn_identifier in scoring_functions.keys():
+        for fn_identifier in request.scoring_functions.keys():
             provider = await self.routing_table.get_provider_impl(fn_identifier)
-            score_response = await provider.score_batch(
-                dataset_id=dataset_id,
-                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            # Create a request for this specific scoring function
+            single_fn_request = ScoreBatchRequest(
+                dataset_id=request.dataset_id,
+                scoring_functions={fn_identifier: request.scoring_functions[fn_identifier]},
+                save_results_dataset=request.save_results_dataset,
             )
+            score_response = await provider.score_batch(single_fn_request)
             res.update(score_response.results)
 
-        if save_results_dataset:
+        if request.save_results_dataset:
             raise NotImplementedError("Save results dataset not implemented yet")
 
         return ScoreBatchResponse(
@@ -63,18 +74,19 @@ class ScoringRouter(Scoring):
 
     async def score(
         self,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        request: ScoreRequest,
     ) -> ScoreResponse:
-        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        logger.debug(f"ScoringRouter.score: {len(request.input_rows)} rows, {len(request.scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions.keys():
+        for fn_identifier in request.scoring_functions.keys():
             provider = await self.routing_table.get_provider_impl(fn_identifier)
-            score_response = await provider.score(
-                input_rows=input_rows,
-                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            # Create a request for this specific scoring function
+            single_fn_request = ScoreRequest(
+                input_rows=request.input_rows,
+                scoring_functions={fn_identifier: request.scoring_functions[fn_identifier]},
             )
+            score_response = await provider.score(single_fn_request)
             res.update(score_response.results)
 
         return ScoreResponse(results=res)
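Both ScoringRouter methods follow the same fan-out: the router accepts one request object, then re-issues a narrowed single-function request to each scoring function's provider. A minimal runnable sketch of that pattern, with stand-in dataclasses in place of the llama_stack_api types:

```python
# Sketch of ScoringRouter.score's per-function fan-out. ScoreRequest,
# ScoreResponse, and EchoProvider are stand-ins, not llama-stack APIs.
import asyncio
from dataclasses import dataclass, field


@dataclass
class ScoreRequest:
    input_rows: list[dict]
    scoring_functions: dict[str, dict | None]


@dataclass
class ScoreResponse:
    results: dict[str, float] = field(default_factory=dict)


class EchoProvider:
    async def score(self, request: ScoreRequest) -> ScoreResponse:
        # Pretend every requested function scores 1.0.
        return ScoreResponse(results={fn: 1.0 for fn in request.scoring_functions})


async def score(request: ScoreRequest) -> ScoreResponse:
    res: dict[str, float] = {}
    for fn_identifier in request.scoring_functions:
        # Stand-in for routing_table.get_provider_impl(fn_identifier).
        provider = EchoProvider()
        # Narrow the request to one scoring function before dispatching.
        single_fn_request = ScoreRequest(
            input_rows=request.input_rows,
            scoring_functions={fn_identifier: request.scoring_functions[fn_identifier]},
        )
        res.update((await provider.score(single_fn_request)).results)
    return ScoreResponse(results=res)


req = ScoreRequest(input_rows=[{"q": "2+2", "a": "4"}], scoring_functions={"exact_match": None})
print(asyncio.run(score(req)).results)  # {'exact_match': 1.0}
```

Narrowing the request per provider keeps each provider unaware of scoring functions routed elsewhere, at the cost of one small request object per function.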
llama_stack/core/routers/eval_scoring.py (continued):

@@ -98,61 +110,139 @@ class EvalRouter(Eval):
 
     async def run_eval(
         self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkConfig,
+        request: RunEvalRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        benchmark_config: BenchmarkConfig | None = None,
     ) -> Job:
-        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.run_eval(
-            benchmark_id=benchmark_id,
-            benchmark_config=benchmark_config,
+        """Run an evaluation on a benchmark.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            benchmark_config: (Deprecated) The benchmark configuration
+
+        Returns:
+            Job object representing the evaluation job
+        """
+        resolved_request = resolve_run_eval_request(
+            request, benchmark_id=benchmark_id, benchmark_config=benchmark_config
         )
+        logger.debug(f"EvalRouter.run_eval: {resolved_request.benchmark_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.run_eval(resolved_request)
 
     async def evaluate_rows(
         self,
-        benchmark_id: str,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: list[str],
-        benchmark_config: BenchmarkConfig,
+        request: EvaluateRowsRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        input_rows: list[dict[str, Any]] | None = None,
+        scoring_functions: list[str] | None = None,
+        benchmark_config: BenchmarkConfig | None = None,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.evaluate_rows(
+        """Evaluate a list of rows on a benchmark.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            input_rows: (Deprecated) The rows to evaluate
+            scoring_functions: (Deprecated) The scoring functions to use
+            benchmark_config: (Deprecated) The benchmark configuration
+
+        Returns:
+            EvaluateResponse object containing generations and scores
+        """
+        resolved_request = resolve_evaluate_rows_request(
+            request,
             benchmark_id=benchmark_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
             benchmark_config=benchmark_config,
         )
+        logger.debug(
+            f"EvalRouter.evaluate_rows: {resolved_request.benchmark_id}, {len(resolved_request.input_rows)} rows"
+        )
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.evaluate_rows(resolved_request)
 
     async def job_status(
         self,
-        benchmark_id: str,
-        job_id: str,
+        request: JobStatusRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        job_id: str | None = None,
     ) -> Job:
-        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.job_status(benchmark_id, job_id)
+        """Get the status of a job.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            job_id: (Deprecated) The job ID
+
+        Returns:
+            Job object with the current status
+        """
+        resolved_request = resolve_job_status_request(request, benchmark_id=benchmark_id, job_id=job_id)
+        logger.debug(f"EvalRouter.job_status: {resolved_request.benchmark_id}, {resolved_request.job_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.job_status(resolved_request)
 
     async def job_cancel(
         self,
-        benchmark_id: str,
-        job_id: str,
+        request: JobCancelRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        job_id: str | None = None,
     ) -> None:
-        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        await provider.job_cancel(
-            benchmark_id,
-            job_id,
-        )
+        """Cancel a job.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            job_id: (Deprecated) The job ID
+
+        Returns:
+            None
+        """
+        resolved_request = resolve_job_cancel_request(request, benchmark_id=benchmark_id, job_id=job_id)
+        logger.debug(f"EvalRouter.job_cancel: {resolved_request.benchmark_id}, {resolved_request.job_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        await provider.job_cancel(resolved_request)
 
     async def job_result(
         self,
-        benchmark_id: str,
-        job_id: str,
+        request: JobResultRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        job_id: str | None = None,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.job_result(
-            benchmark_id,
-            job_id,
-        )
+        """Get the result of a job.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            job_id: (Deprecated) The job ID
+
+        Returns:
+            EvaluateResponse object with the job results
+        """
+        resolved_request = resolve_job_result_request(request, benchmark_id=benchmark_id, job_id=job_id)
+        logger.debug(f"EvalRouter.job_result: {resolved_request.benchmark_id}, {resolved_request.job_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.job_result(resolved_request)
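Every EvalRouter method above funnels its arguments through a resolve_* helper that normalizes the deprecated keyword form into a request object. The real helpers live in llama_stack_api and are not shown in this diff; the sketch below only mirrors the behavior the docstrings describe, with a stand-in RunEvalRequest:

```python
# Sketch of the resolve_* shim contract: prefer the request object, warn on the
# deprecated keyword form, always return a request object. Stand-in types only.
import warnings
from dataclasses import dataclass


@dataclass
class RunEvalRequest:
    benchmark_id: str
    benchmark_config: dict


def resolve_run_eval_request(
    request: RunEvalRequest | None = None,
    *,
    benchmark_id: str | None = None,
    benchmark_config: dict | None = None,
) -> RunEvalRequest:
    if request is not None:
        return request
    warnings.warn(
        "Passing benchmark_id/benchmark_config directly is deprecated; pass a RunEvalRequest.",
        DeprecationWarning,
        stacklevel=2,
    )
    if benchmark_id is None or benchmark_config is None:
        raise ValueError("benchmark_id and benchmark_config are required")
    return RunEvalRequest(benchmark_id=benchmark_id, benchmark_config=benchmark_config)


# Both call styles resolve to the same request object:
assert resolve_run_eval_request(RunEvalRequest("mmlu", {})).benchmark_id == "mmlu"
assert resolve_run_eval_request(benchmark_id="mmlu", benchmark_config={}).benchmark_id == "mmlu"
```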
llama_stack/core/routers/inference.py:

@@ -20,9 +20,11 @@ from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack_api import (
+    GetChatCompletionRequest,
     HealthResponse,
     HealthStatus,
     Inference,
+    ListChatCompletionsRequest,
     ListOpenAIChatCompletionResponse,
     ModelNotFoundError,
     ModelType,
@@ -45,7 +47,7 @@ from llama_stack_api import (
     OpenAIMessageParam,
     OpenAITokenLogProb,
     OpenAITopLogProb,
-    Order,
+    RegisterModelRequest,
     RerankResponse,
     RoutingTable,
 )
@@ -87,7 +89,14 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        request = RegisterModelRequest(
+            model_id=model_id,
+            provider_model_id=provider_model_id,
+            provider_id=provider_id,
+            metadata=metadata,
+            model_type=model_type,
+        )
+        await self.routing_table.register_model(request)
 
     async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
         model = await self.routing_table.get_object_by_identifier("model", model_id)
@@ -229,18 +238,20 @@ class InferenceRouter(Inference):
 
     async def list_chat_completions(
         self,
-        after: str | None = None,
-        limit: int | None = 20,
-        model: str | None = None,
-        order: Order | None = Order.desc,
+        request: ListChatCompletionsRequest,
     ) -> ListOpenAIChatCompletionResponse:
         if self.store:
-            return await self.store.list_chat_completions(after, limit, model, order)
+            return await self.store.list_chat_completions(
+                after=request.after,
+                limit=request.limit,
+                model=request.model,
+                order=request.order,
+            )
         raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
 
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    async def get_chat_completion(self, request: GetChatCompletionRequest) -> OpenAICompletionWithInputMessages:
         if self.store:
-            return await self.store.get_chat_completion(completion_id)
+            return await self.store.get_chat_completion(request.completion_id)
         raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
 
     async def _nonstream_openai_chat_completion(
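list_chat_completions now takes a single ListChatCompletionsRequest whose fields the router unpacks onto the inference store. A caller-side sketch with a stand-in dataclass, where the field names and defaults follow the removed keyword signature:

```python
# Stand-in for llama_stack_api.ListChatCompletionsRequest; fields mirror the
# old after/limit/model/order keywords shown in the diff above.
from dataclasses import dataclass
from enum import Enum


class Order(str, Enum):
    asc = "asc"
    desc = "desc"


@dataclass
class ListChatCompletionsRequest:
    after: str | None = None
    limit: int | None = 20
    model: str | None = None
    order: Order | None = Order.desc


# The router forwards these as after=request.after, limit=request.limit, ...
request = ListChatCompletionsRequest(model="llama-3.3-70b", limit=5)
print(request.after, request.limit, request.model, request.order)
```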
llama_stack/core/routers/safety.py:

@@ -4,14 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any
-
 from opentelemetry import trace
 
 from llama_stack.core.datatypes import SafetyConfig
 from llama_stack.log import get_logger
 from llama_stack.telemetry.helpers import safety_request_span_attributes, safety_span_name
-from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
+from llama_stack_api import (
+    ModerationObject,
+    RegisterShieldRequest,
+    RoutingTable,
+    RunModerationRequest,
+    RunShieldRequest,
+    RunShieldResponse,
+    Safety,
+    Shield,
+    UnregisterShieldRequest,
+)
 
 logger = get_logger(name=__name__, category="core::routers")
 tracer = trace.get_tracer(__name__)
@@ -35,54 +43,38 @@ class SafetyRouter(Safety):
         logger.debug("SafetyRouter.shutdown")
         pass
 
-    async def register_shield(
-        self,
-        shield_id: str,
-        provider_shield_id: str | None = None,
-        provider_id: str | None = None,
-        params: dict[str, Any] | None = None,
-    ) -> Shield:
-        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+    async def register_shield(self, request: RegisterShieldRequest) -> Shield:
+        logger.debug(f"SafetyRouter.register_shield: {request.shield_id}")
+        return await self.routing_table.register_shield(request)
 
     async def unregister_shield(self, identifier: str) -> None:
         logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
-        return await self.routing_table.unregister_shield(identifier)
-
-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] | None = None,
-    ) -> RunShieldResponse:
-        with tracer.start_as_current_span(name=safety_span_name(shield_id)):
-            logger.debug(f"SafetyRouter.run_shield: {shield_id}")
-            provider = await self.routing_table.get_provider_impl(shield_id)
-            response = await provider.run_shield(
-                shield_id=shield_id,
-                messages=messages,
-                params=params,
-            )
-
-            safety_request_span_attributes(shield_id, messages, response)
+        return await self.routing_table.unregister_shield(UnregisterShieldRequest(identifier=identifier))
+
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        with tracer.start_as_current_span(name=safety_span_name(request.shield_id)):
+            logger.debug(f"SafetyRouter.run_shield: {request.shield_id}")
+            provider = await self.routing_table.get_provider_impl(request.shield_id)
+            response = await provider.run_shield(request)
+            safety_request_span_attributes(request.shield_id, request.messages, response)
             return response
 
-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
         list_shields_response = await self.routing_table.list_shields()
         shields = list_shields_response.data
 
         selected_shield: Shield | None = None
-        provider_model: str | None = model
+        provider_model: str | None = request.model
 
-        if model:
-            matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
+        if request.model:
+            matches: list[Shield] = [s for s in shields if request.model == s.provider_resource_id]
             if not matches:
                 raise ValueError(
-                    f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
+                    f"No shield associated with provider_resource id {request.model}: choose from {[s.provider_resource_id for s in shields]}"
                 )
             if len(matches) > 1:
                 raise ValueError(
-                    f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
+                    f"Multiple shields associated with provider_resource id {request.model}: matched shields {[s.identifier for s in matches]}"
                 )
             selected_shield = matches[0]
         else:
@@ -105,9 +97,5 @@ class SafetyRouter(Safety):
         logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)
 
-        response = await provider.run_moderation(
-            input=input,
-            model=provider_model,
-        )
-
-        return response
+        provider_request = RunModerationRequest(input=request.input, model=provider_model)
+        return await provider.run_moderation(provider_request)
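The shield-selection rule inside run_moderation is self-contained: request.model must match exactly one registered shield's provider_resource_id. Extracted as a runnable sketch, with a stand-in Shield type:

```python
# Mirrors the matching logic in SafetyRouter.run_moderation above.
from dataclasses import dataclass


@dataclass
class Shield:
    identifier: str
    provider_resource_id: str


def select_shield(model: str, shields: list[Shield]) -> Shield:
    matches = [s for s in shields if model == s.provider_resource_id]
    if not matches:
        raise ValueError(f"No shield associated with provider_resource id {model}")
    if len(matches) > 1:
        raise ValueError(f"Multiple shields associated with provider_resource id {model}")
    return matches[0]


shields = [Shield("llama-guard", "meta-llama/Llama-Guard-3-8B")]
print(select_shield("meta-llama/Llama-Guard-3-8B", shields).identifier)  # llama-guard
```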
llama_stack/core/routers/vector_io.py:

@@ -39,6 +39,7 @@ from llama_stack_api import (
     VectorStoreFileObject,
     VectorStoreFilesListInBatchResponse,
     VectorStoreFileStatus,
+    VectorStoreListFilesResponse,
     VectorStoreListResponse,
     VectorStoreObject,
     VectorStoreSearchResponsePage,
@@ -148,11 +149,12 @@ class VectorIORouter(VectorIO):
         self,
         params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
     ) -> VectorStoreObject:
-        # Extract llama-stack-specific parameters from extra_body
+        # Extract llama-stack-specific parameters from extra_body or metadata
         extra = params.model_extra or {}
-        embedding_model = extra.get("embedding_model")
-        embedding_dimension = extra.get("embedding_dimension")
-        provider_id = extra.get("provider_id")
+        metadata = params.metadata or {}
+        embedding_model = extra.get("embedding_model", metadata.get("embedding_model"))
+        embedding_dimension = extra.get("embedding_dimension", metadata.get("embedding_dimension"))
+        provider_id = extra.get("provider_id", metadata.get("provider_id"))
 
         # Use default embedding model if not specified
         if (
@@ -166,8 +168,14 @@ class VectorIORouter(VectorIO):
             embedding_model = f"{embedding_provider_id}/{model_id}"
 
         if embedding_model is not None and embedding_dimension is None:
-            # Get the embedding dimension for the model
-            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
+            if (
+                self.vector_stores_config
+                and self.vector_stores_config.default_embedding_model is not None
+                and self.vector_stores_config.default_embedding_model.embedding_dimensions
+            ):
+                embedding_dimension = self.vector_stores_config.default_embedding_model.embedding_dimensions
+            else:
+                embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
         # Validate that embedding model exists and is of the correct type
         if embedding_model is not None:
             model = await self.routing_table.get_object_by_identifier("model", embedding_model)
@@ -376,7 +384,7 @@ class VectorIORouter(VectorIO):
         after: str | None = None,
         before: str | None = None,
         filter: VectorStoreFileStatus | None = None,
-    ) ->
+    ) -> VectorStoreListFilesResponse:
         logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
         return await self.routing_table.openai_list_files_in_vector_store(
             vector_store_id=vector_store_id,
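The create-vector-store change gives metadata a fallback role for the llama-stack-specific parameters. A sketch of just that lookup, using plain dicts in place of the pydantic request object:

```python
# Mirrors the extra_body-then-metadata lookup in openai_create_vector_store
# above; extra and metadata are plain dicts here.
def extract_params(extra: dict, metadata: dict) -> tuple:
    embedding_model = extra.get("embedding_model", metadata.get("embedding_model"))
    embedding_dimension = extra.get("embedding_dimension", metadata.get("embedding_dimension"))
    provider_id = extra.get("provider_id", metadata.get("provider_id"))
    return embedding_model, embedding_dimension, provider_id


# metadata is consulted only for keys absent from extra_body:
print(extract_params({"embedding_model": "all-MiniLM-L6-v2"}, {"embedding_dimension": 384}))
# -> ('all-MiniLM-L6-v2', 384, None)
```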
llama_stack/core/routing_tables/models.py:

@@ -16,6 +16,7 @@ from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProv
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
 from llama_stack_api import (
+    GetModelRequest,
     ListModelsResponse,
     Model,
     ModelNotFoundError,
@@ -23,6 +24,8 @@ from llama_stack_api import (
     ModelType,
     OpenAIListModelsResponse,
     OpenAIModel,
+    RegisterModelRequest,
+    UnregisterModelRequest,
 )
 
 from .common import CommonRoutingTableImpl, lookup_model
@@ -171,7 +174,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         ]
         return OpenAIListModelsResponse(data=openai_models)
 
-    async def get_model(self, model_id: str) -> Model:
+    async def get_model(self, request_or_model_id: GetModelRequest | str) -> Model:
+        # Support both the public Models API (GetModelRequest) and internal ModelStore interface (string)
+        if isinstance(request_or_model_id, GetModelRequest):
+            model_id = request_or_model_id.model_id
+        else:
+            model_id = request_or_model_id
         return await lookup_model(self, model_id)
 
     async def get_provider_impl(self, model_id: str) -> Any:
@@ -195,12 +203,28 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
     async def register_model(
         self,
-        model_id: str,
+        request: RegisterModelRequest | str | None = None,
+        *,
+        model_id: str | None = None,
         provider_model_id: str | None = None,
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
     ) -> Model:
+        # Support both the public Models API (RegisterModelRequest) and legacy parameter-based interface
+        if isinstance(request, RegisterModelRequest):
+            model_id = request.model_id
+            provider_model_id = request.provider_model_id
+            provider_id = request.provider_id
+            metadata = request.metadata
+            model_type = request.model_type
+        elif isinstance(request, str):
+            # Legacy positional argument: register_model("model-id", ...)
+            model_id = request
+
+        if model_id is None:
+            raise ValueError("Either request or model_id must be provided")
+
         if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this model
             if len(self.impls_by_provider_id) == 1:
@@ -229,7 +253,22 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         registered_model = await self.register_object(model)
         return registered_model
 
-    async def unregister_model(self, model_id: str) -> None:
+    async def unregister_model(
+        self,
+        request: UnregisterModelRequest | str | None = None,
+        *,
+        model_id: str | None = None,
+    ) -> None:
+        # Support both the public Models API (UnregisterModelRequest) and legacy parameter-based interface
+        if isinstance(request, UnregisterModelRequest):
+            model_id = request.model_id
+        elif isinstance(request, str):
+            # Legacy positional argument: unregister_model("model-id")
+            model_id = request
+
+        if model_id is None:
+            raise ValueError("Either request or model_id must be provided")
+
         existing_model = await self.get_model(model_id)
         if existing_model is None:
             raise ModelNotFoundError(model_id)
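register_model and unregister_model now accept three call shapes: a request object, a legacy positional string, or a keyword model_id. The normalization is the same in both methods; a sketch with a stand-in request type:

```python
# Mirrors the dual-interface normalization in ModelsRoutingTable above.
from dataclasses import dataclass


@dataclass
class UnregisterModelRequest:
    model_id: str


def normalize_model_id(
    request: UnregisterModelRequest | str | None = None,
    *,
    model_id: str | None = None,
) -> str:
    if isinstance(request, UnregisterModelRequest):
        model_id = request.model_id
    elif isinstance(request, str):
        # Legacy positional argument: unregister_model("model-id")
        model_id = request
    if model_id is None:
        raise ValueError("Either request or model_id must be provided")
    return model_id


assert normalize_model_id(UnregisterModelRequest("m1")) == "m1"
assert normalize_model_id("m1") == "m1"
assert normalize_model_id(model_id="m1") == "m1"
```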
llama_stack/core/routing_tables/scoring_functions.py:

@@ -9,12 +9,14 @@ from llama_stack.core.datatypes import (
 )
 from llama_stack.log import get_logger
 from llama_stack_api import (
+    GetScoringFunctionRequest,
+    ListScoringFunctionsRequest,
     ListScoringFunctionsResponse,
-    ParamType,
+    RegisterScoringFunctionRequest,
     ResourceType,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctions,
+    UnregisterScoringFunctionRequest,
 )
 
 from .common import CommonRoutingTableImpl
@@ -23,26 +25,23 @@ logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
-    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
+    async def list_scoring_functions(self, request: ListScoringFunctionsRequest) -> ListScoringFunctionsResponse:
         return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
 
-    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
-        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
+    async def get_scoring_function(self, request: GetScoringFunctionRequest) -> ScoringFn:
+        scoring_fn = await self.get_object_by_identifier("scoring_function", request.scoring_fn_id)
         if scoring_fn is None:
-            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
+            raise ValueError(f"Scoring function '{request.scoring_fn_id}' not found")
         return scoring_fn
 
     async def register_scoring_function(
         self,
-        scoring_fn_id: str,
-        description: str,
-        return_type: ParamType,
-        provider_scoring_fn_id: str | None = None,
-        provider_id: str | None = None,
-        params: ScoringFnParams | None = None,
+        request: RegisterScoringFunctionRequest,
     ) -> None:
+        provider_scoring_fn_id = request.provider_scoring_fn_id
         if provider_scoring_fn_id is None:
-            provider_scoring_fn_id = scoring_fn_id
+            provider_scoring_fn_id = request.scoring_fn_id
+        provider_id = request.provider_id
         if provider_id is None:
             if len(self.impls_by_provider_id) == 1:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -51,16 +50,17 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
                 "No provider specified and multiple providers available. Please specify a provider_id."
             )
         scoring_fn = ScoringFnWithOwner(
-            identifier=scoring_fn_id,
-            description=description,
-            return_type=return_type,
+            identifier=request.scoring_fn_id,
+            description=request.description,
+            return_type=request.return_type,
             provider_resource_id=provider_scoring_fn_id,
             provider_id=provider_id,
-            params=params,
+            params=request.params,
         )
         scoring_fn.provider_id = provider_id
         await self.register_object(scoring_fn)
 
-    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
-        existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
+    async def unregister_scoring_function(self, request: UnregisterScoringFunctionRequest) -> None:
+        get_request = GetScoringFunctionRequest(scoring_fn_id=request.scoring_fn_id)
+        existing_scoring_fn = await self.get_scoring_function(get_request)
         await self.unregister_object(existing_scoring_fn)
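register_scoring_function applies two defaulting rules before constructing the resource: the provider-side id falls back to the public id, and a lone configured provider is selected automatically. A runnable sketch of those rules, with a stand-in request class mirroring the fields in the diff:

```python
# Mirrors the defaulting logic in register_scoring_function above.
from dataclasses import dataclass


@dataclass
class RegisterScoringFunctionRequest:
    scoring_fn_id: str
    description: str
    provider_scoring_fn_id: str | None = None
    provider_id: str | None = None


def resolve_ids(request: RegisterScoringFunctionRequest, providers: dict) -> tuple[str, str]:
    provider_scoring_fn_id = request.provider_scoring_fn_id
    if provider_scoring_fn_id is None:
        provider_scoring_fn_id = request.scoring_fn_id
    provider_id = request.provider_id
    if provider_id is None:
        if len(providers) == 1:
            provider_id = next(iter(providers))
        else:
            raise ValueError("No provider specified and multiple providers available. Please specify a provider_id.")
    return provider_scoring_fn_id, provider_id


req = RegisterScoringFunctionRequest(scoring_fn_id="basic::equality", description="exact equality")
print(resolve_ids(req, {"basic": object()}))  # ('basic::equality', 'basic')
```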