llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
- llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/inline/agents/meta_reference/responses/streaming.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import time
 import uuid
 from collections.abc import AsyncIterator
 from typing import Any
@@ -16,6 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 from llama_stack_api import (
     AllowedToolsFilter,
     ApprovalFilter,
+    Connectors,
     Inference,
     MCPListToolsTool,
     ModelNotFoundError,
@@ -30,6 +32,7 @@ from llama_stack_api import (
     OpenAIChatCompletionToolChoiceFunctionTool,
     OpenAIChoice,
     OpenAIChoiceLogprobs,
+    OpenAIFinishReason,
     OpenAIMessageParam,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseContentPartReasoningText,
@@ -77,6 +80,7 @@ from llama_stack_api import (
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
     OpenAIResponsePrompt,
+    OpenAIResponseReasoning,
     OpenAIResponseText,
     OpenAIResponseUsage,
     OpenAIResponseUsageInputTokensDetails,
@@ -133,11 +137,15 @@ class StreamingResponseOrchestrator:
         instructions: str | None,
         safety_api: Safety | None,
         guardrail_ids: list[str] | None = None,
+        connectors_api: Connectors | None = None,
         prompt: OpenAIResponsePrompt | None = None,
         parallel_tool_calls: bool | None = None,
         max_tool_calls: int | None = None,
+        reasoning: OpenAIResponseReasoning | None = None,
+        max_output_tokens: int | None = None,
         metadata: dict[str, str] | None = None,
         include: list[ResponseItemInclude] | None = None,
+        store: bool | None = True,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -147,6 +155,7 @@ class StreamingResponseOrchestrator:
         self.max_infer_iters = max_infer_iters
         self.tool_executor = tool_executor
         self.safety_api = safety_api
+        self.connectors_api = connectors_api
         self.guardrail_ids = guardrail_ids or []
         self.prompt = prompt
         # System message that is inserted into the model's context
@@ -155,8 +164,13 @@ class StreamingResponseOrchestrator:
         self.parallel_tool_calls = parallel_tool_calls
         # Max number of total calls to built-in tools that can be processed in a response
         self.max_tool_calls = max_tool_calls
+        self.reasoning = reasoning
+        # An upper bound for the number of tokens that can be generated for a response
+        self.max_output_tokens = max_output_tokens
         self.metadata = metadata
+        self.store = store
         self.include = include
+        self.store = bool(store) if store is not None else True
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -179,6 +193,8 @@ class StreamingResponseOrchestrator:
         self.violation_detected = False
         # Track total calls made to built-in tools
         self.accumulated_builtin_tool_calls = 0
+        # Track total output tokens generated across inference calls
+        self.accumulated_builtin_output_tokens = 0
 
     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
@@ -191,7 +207,9 @@ class StreamingResponseOrchestrator:
             model=self.ctx.model,
             status="completed",
             output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
+            max_output_tokens=self.max_output_tokens,
             metadata=self.metadata,
+            store=self.store,
         )
 
         return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
@@ -212,8 +230,10 @@ class StreamingResponseOrchestrator:
         *,
         error: OpenAIResponseError | None = None,
     ) -> OpenAIResponseObject:
+        completed_at = int(time.time()) if status == "completed" else None
         return OpenAIResponseObject(
             created_at=self.created_at,
+            completed_at=completed_at,
             id=self.response_id,
             model=self.ctx.model,
             object="response",
@@ -228,7 +248,10 @@
             prompt=self.prompt,
             parallel_tool_calls=self.parallel_tool_calls,
             max_tool_calls=self.max_tool_calls,
+            reasoning=self.reasoning,
+            max_output_tokens=self.max_output_tokens,
             metadata=self.metadata,
+            store=self.store,
         )
 
     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -292,6 +315,22 @@
 
         try:
             while True:
+                if (
+                    self.max_output_tokens is not None
+                    and self.accumulated_builtin_output_tokens >= self.max_output_tokens
+                ):
+                    logger.info(
+                        "Skipping inference call since max_output_tokens reached: "
+                        f"{self.accumulated_builtin_output_tokens}/{self.max_output_tokens}"
+                    )
+                    final_status = "incomplete"
+                    break
+
+                remaining_output_tokens = (
+                    self.max_output_tokens - self.accumulated_builtin_output_tokens
+                    if self.max_output_tokens is not None
+                    else None
+                )
                 # Text is the default response format for chat completion so don't need to pass it
                 # (some providers don't support non-empty response_format when tools are present)
                 response_format = (
@@ -311,6 +350,11 @@
                     True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else None
                 )
 
+                # In OpenAI, parallel_tool_calls is only allowed when 'tools' are specified.
+                effective_parallel_tool_calls = (
+                    self.parallel_tool_calls if effective_tools is not None and len(effective_tools) > 0 else None
+                )
+
                 params = OpenAIChatCompletionRequestWithExtraBody(
                     model=self.ctx.model,
                     messages=messages,
@@ -324,6 +368,9 @@
                         "include_usage": True,
                     },
                     logprobs=logprobs,
+                    parallel_tool_calls=effective_parallel_tool_calls,
+                    reasoning_effort=self.reasoning.effort if self.reasoning else None,
+                    max_completion_tokens=remaining_output_tokens,
                 )
                 completion_result = await self.inference_api.openai_chat_completion(params)
 
@@ -480,23 +527,24 @@
         if not chunk.usage:
             return
 
+        self.accumulated_builtin_output_tokens += chunk.usage.completion_tokens
+
         if self.accumulated_usage is None:
             # Convert from chat completion format to response format
             self.accumulated_usage = OpenAIResponseUsage(
                 input_tokens=chunk.usage.prompt_tokens,
                 output_tokens=chunk.usage.completion_tokens,
                 total_tokens=chunk.usage.total_tokens,
-                input_tokens_details=(
-                    OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
-                    if chunk.usage.prompt_tokens_details
-                    else None
+                input_tokens_details=OpenAIResponseUsageInputTokensDetails(
+                    cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens
+                    if chunk.usage.prompt_tokens_details and chunk.usage.prompt_tokens_details.cached_tokens is not None
+                    else 0
                 ),
-                output_tokens_details=(
-                    OpenAIResponseUsageOutputTokensDetails(
-                        reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
-                    )
+                output_tokens_details=OpenAIResponseUsageOutputTokensDetails(
+                    reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
                     if chunk.usage.completion_tokens_details
-                    else None
+                    and chunk.usage.completion_tokens_details.reasoning_tokens is not None
+                    else 0
                 ),
             )
         else:
@@ -506,17 +554,16 @@
                 output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
                 total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
                 # Use latest non-null details
-                input_tokens_details=(
-                    OpenAIResponseUsageInputTokensDetails(cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens)
-                    if chunk.usage.prompt_tokens_details
-                    else self.accumulated_usage.input_tokens_details
+                input_tokens_details=OpenAIResponseUsageInputTokensDetails(
+                    cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens
+                    if chunk.usage.prompt_tokens_details and chunk.usage.prompt_tokens_details.cached_tokens is not None
+                    else self.accumulated_usage.input_tokens_details.cached_tokens
                 ),
-                output_tokens_details=(
-                    OpenAIResponseUsageOutputTokensDetails(
-                        reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
-                    )
+                output_tokens_details=OpenAIResponseUsageOutputTokensDetails(
+                    reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
                     if chunk.usage.completion_tokens_details
-                    else self.accumulated_usage.output_tokens_details
+                    and chunk.usage.completion_tokens_details.reasoning_tokens is not None
+                    else self.accumulated_usage.output_tokens_details.reasoning_tokens
                 ),
             )
 
@@ -652,7 +699,7 @@
         chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
         chunk_created = 0
         chunk_model = ""
-        chunk_finish_reason = ""
+        chunk_finish_reason: OpenAIFinishReason = "stop"
         chat_response_logprobs = []
 
         # Create a placeholder message item for delta events
@@ -744,9 +791,9 @@
                 chunk_finish_reason = chunk_choice.finish_reason
 
                 # Handle reasoning content if present (non-standard field for o1/o3 models)
-                if hasattr(chunk_choice.delta, "reasoning_content") and chunk_choice.delta.reasoning_content:
+                if hasattr(chunk_choice.delta, "reasoning") and chunk_choice.delta.reasoning:
                     async for event in self._handle_reasoning_content_chunk(
-                        reasoning_content=chunk_choice.delta.reasoning_content,
+                        reasoning_content=chunk_choice.delta.reasoning,
                         reasoning_part_emitted=reasoning_part_emitted,
                         reasoning_content_index=reasoning_content_index,
                         message_item_id=message_item_id,
@@ -758,7 +805,7 @@
                         else:
                             yield event
                     reasoning_part_emitted = True
-                    reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
+                    reasoning_text_accumulated.append(chunk_choice.delta.reasoning)
 
                 # Handle refusal content if present
                 if chunk_choice.delta.refusal:
@@ -1175,6 +1222,9 @@
         """Process an MCP tool configuration and emit appropriate streaming events."""
        from llama_stack.providers.utils.tools.mcp import list_mcp_tools
 
+        # Resolve connector_id to server_url if provided
+        mcp_tool = await resolve_mcp_connector_id(mcp_tool, self.connectors_api)
+
         # Emit mcp_list_tools.in_progress
         self.sequence_number += 1
         yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
@@ -1489,3 +1539,25 @@ async def _process_tool_choice(
         tools=tool_choice,
         mode="required",
     )
+
+
+async def resolve_mcp_connector_id(
+    mcp_tool: OpenAIResponseInputToolMCP,
+    connectors_api: Connectors,
+) -> OpenAIResponseInputToolMCP:
+    """Resolve connector_id to server_url for an MCP tool.
+
+    If the mcp_tool has a connector_id but no server_url, this function
+    looks up the connector and populates the server_url from it.
+
+    Args:
+        mcp_tool: The MCP tool configuration to resolve
+        connectors_api: The connectors API for looking up connectors
+
+    Returns:
+        The mcp_tool with server_url populated (may be same instance if already set)
+    """
+    if mcp_tool.connector_id and not mcp_tool.server_url:
+        connector = await connectors_api.get_connector(mcp_tool.connector_id)
+        return mcp_tool.model_copy(update={"server_url": connector.url})
+    return mcp_tool
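Note: the streaming.py hunks above add a per-response output-token budget. Each usage chunk's completion_tokens is accumulated, the remaining budget is passed to the next chat completion as max_completion_tokens, and the loop stops with an "incomplete" status once the budget is spent. Below is a minimal standalone sketch of that accounting; the helper name and the simulated token counts are illustrative, not part of the package.

```python
def plan_next_call(max_output_tokens: int | None, used_tokens: int) -> tuple[bool, int | None]:
    """Mirror the budget check added to StreamingResponseOrchestrator.

    Returns (should_call_inference, max_completion_tokens_for_this_call).
    """
    if max_output_tokens is not None and used_tokens >= max_output_tokens:
        # The orchestrator marks the response "incomplete" and stops looping here.
        return False, None
    remaining = max_output_tokens - used_tokens if max_output_tokens is not None else None
    return True, remaining


# Simulate successive inference rounds against a 100-token budget.
used = 0
for completion_tokens in (40, 45, 30, 10):  # usage reported by each round's chunks
    ok, cap = plan_next_call(100, used)
    if not ok:
        print(f"stop: {used}/100 output tokens used, response marked incomplete")
        break
    print(f"call inference with max_completion_tokens={cap}")
    used += completion_tokens  # mirrors accumulated_builtin_output_tokens
```

Providers may still overshoot the cap slightly within a single call; the check only prevents starting another inference round once the budget is exhausted.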
llama_stack/providers/inline/agents/meta_reference/responses/types.py

@@ -12,6 +12,7 @@ from pydantic import BaseModel
 
 from llama_stack_api import (
     OpenAIChatCompletionToolCall,
+    OpenAIFinishReason,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     OpenAIResponseInput,
@@ -52,7 +53,7 @@ class ChatCompletionResult:
     tool_calls: dict[int, OpenAIChatCompletionToolCall]
     created: int
     model: str
-    finish_reason: str
+    finish_reason: OpenAIFinishReason
     message_item_id: str  # For streaming events
     tool_call_item_ids: dict[int, str]  # For streaming events
     content_part_emitted: bool  # Tracking state
llama_stack/providers/inline/agents/meta_reference/responses/utils.py

@@ -53,6 +53,7 @@ from llama_stack_api import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
     ResponseGuardrailSpec,
+    RunModerationRequest,
     Safety,
 )
 
@@ -468,7 +469,9 @@ async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids
         else:
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
 
-    guardrail_tasks = [safety_api.run_moderation(input=messages, model=model_id) for model_id in model_ids]
+    guardrail_tasks = [
+        safety_api.run_moderation(RunModerationRequest(input=messages, model=model_id)) for model_id in model_ids
+    ]
     responses = await asyncio.gather(*guardrail_tasks)
 
     for response in responses:
llama_stack/providers/inline/agents/meta_reference/safety.py

@@ -7,7 +7,7 @@
 import asyncio
 
 from llama_stack.log import get_logger
-from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
+from llama_stack_api import OpenAIMessageParam, RunShieldRequest, Safety, SafetyViolation, ViolationLevel
 
 log = get_logger(name=__name__, category="agents::meta_reference")
 
@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
     async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
         responses = await asyncio.gather(
             *[
-                self.safety_api.run_shield(shield_id=identifier, messages=messages)
+                self.safety_api.run_shield(RunShieldRequest(shield_id=identifier, messages=messages))
                 for identifier in identifiers
             ]
         )
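Note: both the utils.py and safety.py hunks above switch the safety calls from keyword arguments to request objects (run_moderation takes a RunModerationRequest, run_shield takes a RunShieldRequest), with the calls still fanned out concurrently via asyncio.gather. A minimal runnable sketch of that calling pattern follows; the RunShieldRequest dataclass and StubSafety class here are stand-ins, not the actual llama_stack_api types.

```python
import asyncio
from dataclasses import dataclass


# Illustrative stand-ins; the real RunShieldRequest / Safety types come from llama_stack_api.
@dataclass
class RunShieldRequest:
    shield_id: str
    messages: list[str]


class StubSafety:
    async def run_shield(self, request: RunShieldRequest) -> str:
        # A real provider would evaluate the messages and return a shield response.
        return f"{request.shield_id}: ok"


async def run_multiple_shields(safety_api: StubSafety, messages: list[str], identifiers: list[str]) -> list[str]:
    # Same shape as ShieldRunnerMixin.run_multiple_shields after the change:
    # each call now takes a single request object instead of keyword arguments.
    return await asyncio.gather(
        *[
            safety_api.run_shield(RunShieldRequest(shield_id=identifier, messages=messages))
            for identifier in identifiers
        ]
    )


print(asyncio.run(run_multiple_shields(StubSafety(), ["hello"], ["llama-guard", "prompt-guard"])))
```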
llama_stack/providers/inline/batches/reference/batches.py

@@ -23,6 +23,7 @@ from llama_stack_api import (
     BatchObject,
     ConflictError,
     Files,
+    GetModelRequest,
     Inference,
     ListBatchesResponse,
     Models,
@@ -485,7 +486,7 @@ class ReferenceBatchesImpl(Batches):
 
         if "model" in request_body and isinstance(request_body["model"], str):
             try:
-                await self.models_api.get_model(request_body["model"])
+                await self.models_api.get_model(GetModelRequest(model_id=request_body["model"]))
             except Exception:
                 errors.append(
                     BatchError(
llama_stack/providers/inline/eval/meta_reference/eval.py

@@ -13,19 +13,25 @@ from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack_api import (
     Agents,
     Benchmark,
-    BenchmarkConfig,
     BenchmarksProtocolPrivate,
     DatasetIO,
     Datasets,
     Eval,
     EvaluateResponse,
+    EvaluateRowsRequest,
     Inference,
+    IterRowsRequest,
     Job,
+    JobCancelRequest,
+    JobResultRequest,
     JobStatus,
+    JobStatusRequest,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAICompletionRequestWithExtraBody,
     OpenAISystemMessageParam,
     OpenAIUserMessageParam,
+    RunEvalRequest,
+    ScoreRequest,
     Scoring,
 )
 
@@ -90,10 +96,9 @@ class MetaReferenceEvalImpl(
 
     async def run_eval(
         self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkConfig,
+        request: RunEvalRequest,
     ) -> Job:
-        task_def = self.benchmarks[benchmark_id]
+        task_def = self.benchmarks[request.benchmark_id]
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
 
@@ -101,15 +106,18 @@ class MetaReferenceEvalImpl(
         # dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
 
         all_rows = await self.datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
+            IterRowsRequest(
+                dataset_id=dataset_id,
+                limit=(-1 if request.benchmark_config.num_examples is None else request.benchmark_config.num_examples),
+            )
         )
-        res = await self.evaluate_rows(
-            benchmark_id=benchmark_id,
+        eval_rows_request = EvaluateRowsRequest(
+            benchmark_id=request.benchmark_id,
             input_rows=all_rows.data,
             scoring_functions=scoring_functions,
-            benchmark_config=benchmark_config,
+            benchmark_config=request.benchmark_config,
         )
+        res = await self.evaluate_rows(eval_rows_request)
 
         # TODO: currently needs to wait for generation before returning
         # need job scheduler queue (ray/celery) w/ jobs api
@@ -118,9 +126,9 @@ class MetaReferenceEvalImpl(
         return Job(job_id=job_id, status=JobStatus.completed)
 
     async def _run_model_generation(
-        self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
+        self, input_rows: list[dict[str, Any]], request: EvaluateRowsRequest
    ) -> list[dict[str, Any]]:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
         sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
 
@@ -165,50 +173,50 @@ class MetaReferenceEvalImpl(
 
     async def evaluate_rows(
         self,
-        benchmark_id: str,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: list[str],
-        benchmark_config: BenchmarkConfig,
+        request: EvaluateRowsRequest,
     ) -> EvaluateResponse:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         # Agent evaluation removed
         if candidate.type == "model":
-            generations = await self._run_model_generation(input_rows, benchmark_config)
+            generations = await self._run_model_generation(request.input_rows, request)
         else:
             raise ValueError(f"Invalid candidate type: {candidate.type}")
 
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(request.input_rows, generations, strict=False)
         ]
 
-        if benchmark_config.scoring_params is not None:
+        if request.benchmark_config.scoring_params is not None:
             scoring_functions_dict = {
-                scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
-                for scoring_fn_id in scoring_functions
+                scoring_fn_id: request.benchmark_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in request.scoring_functions
             }
         else:
-            scoring_functions_dict = dict.fromkeys(scoring_functions)
+            scoring_functions_dict = dict.fromkeys(request.scoring_functions)
 
-        score_response = await self.scoring_api.score(
-            input_rows=score_input_rows,
+        score_request = ScoreRequest(
+            input_rows=score_input_rows,
+            scoring_functions=scoring_functions_dict,
         )
+        score_response = await self.scoring_api.score(score_request)
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
-        if job_id in self.jobs:
-            return Job(job_id=job_id, status=JobStatus.completed)
+    async def job_status(self, request: JobStatusRequest) -> Job:
+        if request.job_id in self.jobs:
+            return Job(job_id=request.job_id, status=JobStatus.completed)
 
-        raise ValueError(f"Job {job_id} not found")
+        raise ValueError(f"Job {request.job_id} not found")
 
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
+    async def job_cancel(self, request: JobCancelRequest) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        job = await self.job_status(benchmark_id, job_id)
+    async def job_result(self, request: JobResultRequest) -> EvaluateResponse:
+        job_status_request = JobStatusRequest(benchmark_id=request.benchmark_id, job_id=request.job_id)
+        job = await self.job_status(job_status_request)
         status = job.status
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")
 
-        return self.jobs[job_id]
+        return self.jobs[request.job_id]
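Note: the eval.py hunks above replace positional/keyword arguments with typed request objects (RunEvalRequest, EvaluateRowsRequest, JobStatusRequest, JobResultRequest, ScoreRequest). A hedged sketch of what a caller looks like against the new shapes follows; the benchmark id, the benchmark config object, and the eval_api handle are placeholders, and only the request fields visible in the diff are used.

```python
from llama_stack_api import JobResultRequest, JobStatusRequest, RunEvalRequest


async def run_benchmark(eval_api, my_benchmark_config):
    # Kick off an eval run; the benchmark id and config now travel inside the request object.
    job = await eval_api.run_eval(
        RunEvalRequest(benchmark_id="my-benchmark", benchmark_config=my_benchmark_config)
    )

    # Poll the job; status lookups also take a request object rather than two arguments.
    job_state = await eval_api.job_status(
        JobStatusRequest(benchmark_id="my-benchmark", job_id=job.job_id)
    )

    if job_state.status is not None and job_state.status.value == "completed":
        # Fetch the EvaluateResponse for the finished job.
        return await eval_api.job_result(
            JobResultRequest(benchmark_id="my-benchmark", job_id=job.job_id)
        )
    return None
```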
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h

@@ -0,0 +1,9 @@
+#import <Foundation/Foundation.h>
+
+//! Project version number for LocalInference.
+FOUNDATION_EXPORT double LocalInferenceVersionNumber;
+
+//! Project version string for LocalInference.
+FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
+
+// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>