llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
- llama_stack-0.5.0.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import time
 import uuid
 from collections.abc import AsyncIterator
 from typing import Any
@@ -16,6 +17,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 from llama_stack_api import (
     AllowedToolsFilter,
     ApprovalFilter,
+    Connectors,
     Inference,
     MCPListToolsTool,
     ModelNotFoundError,
@@ -30,6 +32,7 @@ from llama_stack_api import (
     OpenAIChatCompletionToolChoiceFunctionTool,
     OpenAIChoice,
     OpenAIChoiceLogprobs,
+    OpenAIFinishReason,
     OpenAIMessageParam,
     OpenAIResponseContentPartOutputText,
     OpenAIResponseContentPartReasoningText,
@@ -77,6 +80,7 @@ from llama_stack_api import (
     OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
     OpenAIResponsePrompt,
+    OpenAIResponseReasoning,
     OpenAIResponseText,
     OpenAIResponseUsage,
     OpenAIResponseUsageInputTokensDetails,
@@ -133,11 +137,16 @@ class StreamingResponseOrchestrator:
         instructions: str | None,
         safety_api: Safety | None,
         guardrail_ids: list[str] | None = None,
+        connectors_api: Connectors | None = None,
         prompt: OpenAIResponsePrompt | None = None,
         parallel_tool_calls: bool | None = None,
         max_tool_calls: int | None = None,
+        reasoning: OpenAIResponseReasoning | None = None,
+        max_output_tokens: int | None = None,
+        safety_identifier: str | None = None,
         metadata: dict[str, str] | None = None,
         include: list[ResponseItemInclude] | None = None,
+        store: bool | None = True,
     ):
         self.inference_api = inference_api
         self.ctx = ctx
@@ -147,6 +156,7 @@ class StreamingResponseOrchestrator:
         self.max_infer_iters = max_infer_iters
         self.tool_executor = tool_executor
         self.safety_api = safety_api
+        self.connectors_api = connectors_api
         self.guardrail_ids = guardrail_ids or []
         self.prompt = prompt
         # System message that is inserted into the model's context
@@ -155,8 +165,14 @@ class StreamingResponseOrchestrator:
         self.parallel_tool_calls = parallel_tool_calls
         # Max number of total calls to built-in tools that can be processed in a response
         self.max_tool_calls = max_tool_calls
+        self.reasoning = reasoning
+        # An upper bound for the number of tokens that can be generated for a response
+        self.max_output_tokens = max_output_tokens
+        self.safety_identifier = safety_identifier
         self.metadata = metadata
+        self.store = store
         self.include = include
+        self.store = bool(store) if store is not None else True
         self.sequence_number = 0
         # Store MCP tool mapping that gets built during tool processing
         self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
@@ -179,6 +195,8 @@ class StreamingResponseOrchestrator:
         self.violation_detected = False
         # Track total calls made to built-in tools
         self.accumulated_builtin_tool_calls = 0
+        # Track total output tokens generated across inference calls
+        self.accumulated_builtin_output_tokens = 0

     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
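Note on the `store` handling above: `store=None` means the caller did not state a preference, and the orchestrator falls back to the OpenAI default of storing the response. A minimal standalone sketch of that normalization (illustrative helper only, not part of the class):

    def normalize_store(store: bool | None) -> bool:
        # None means "not specified" and falls back to the default of True
        return bool(store) if store is not None else True

    assert normalize_store(None) is True
    assert normalize_store(False) is False
    assert normalize_store(True) is True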
@@ -191,7 +209,10 @@
             model=self.ctx.model,
             status="completed",
             output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
+            max_output_tokens=self.max_output_tokens,
+            safety_identifier=self.safety_identifier,
             metadata=self.metadata,
+            store=self.store,
         )

         return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
@@ -212,8 +233,10 @@
         *,
         error: OpenAIResponseError | None = None,
     ) -> OpenAIResponseObject:
+        completed_at = int(time.time()) if status == "completed" else None
         return OpenAIResponseObject(
             created_at=self.created_at,
+            completed_at=completed_at,
             id=self.response_id,
             model=self.ctx.model,
             object="response",
@@ -228,7 +251,11 @@
             prompt=self.prompt,
             parallel_tool_calls=self.parallel_tool_calls,
             max_tool_calls=self.max_tool_calls,
+            reasoning=self.reasoning,
+            max_output_tokens=self.max_output_tokens,
+            safety_identifier=self.safety_identifier,
             metadata=self.metadata,
+            store=self.store,
         )

     async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
@@ -292,6 +319,22 @@

         try:
             while True:
+                if (
+                    self.max_output_tokens is not None
+                    and self.accumulated_builtin_output_tokens >= self.max_output_tokens
+                ):
+                    logger.info(
+                        "Skipping inference call since max_output_tokens reached: "
+                        f"{self.accumulated_builtin_output_tokens}/{self.max_output_tokens}"
+                    )
+                    final_status = "incomplete"
+                    break
+
+                remaining_output_tokens = (
+                    self.max_output_tokens - self.accumulated_builtin_output_tokens
+                    if self.max_output_tokens is not None
+                    else None
+                )
                 # Text is the default response format for chat completion so don't need to pass it
                 # (some providers don't support non-empty response_format when tools are present)
                 response_format = (
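The loop above stops with an "incomplete" status once the accumulated output tokens reach `max_output_tokens`, and otherwise hands the remaining budget to the next inference call. A standalone sketch of the budgeting rule (illustrative helper, not part of the diff):

    def remaining_output_tokens(max_output_tokens: int | None, used: int) -> int | None:
        # No cap requested: pass None through so the provider applies its own default
        if max_output_tokens is None:
            return None
        return max_output_tokens - used

    assert remaining_output_tokens(None, 500) is None
    assert remaining_output_tokens(1000, 400) == 600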
@@ -311,6 +354,11 @@
                     True if self.include and ResponseItemInclude.message_output_text_logprobs in self.include else None
                 )

+                # In OpenAI, parallel_tool_calls is only allowed when 'tools' are specified.
+                effective_parallel_tool_calls = (
+                    self.parallel_tool_calls if effective_tools is not None and len(effective_tools) > 0 else None
+                )
+
                 params = OpenAIChatCompletionRequestWithExtraBody(
                     model=self.ctx.model,
                     messages=messages,
@@ -324,6 +372,10 @@
                         "include_usage": True,
                     },
                     logprobs=logprobs,
+                    parallel_tool_calls=effective_parallel_tool_calls,
+                    reasoning_effort=self.reasoning.effort if self.reasoning else None,
+                    safety_identifier=self.safety_identifier,
+                    max_completion_tokens=remaining_output_tokens,
                 )
                 completion_result = await self.inference_api.openai_chat_completion(params)

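As the comment in the hunk notes, OpenAI only accepts `parallel_tool_calls` when tools are supplied, so the flag is forwarded only for a non-empty tool list. A small sketch of that gating (hypothetical helper for illustration):

    def gate_parallel_tool_calls(parallel_tool_calls: bool | None, tools: list | None) -> bool | None:
        # Drop the flag entirely when there are no tools to call in parallel
        return parallel_tool_calls if tools is not None and len(tools) > 0 else None

    assert gate_parallel_tool_calls(True, []) is None
    assert gate_parallel_tool_calls(True, [{"type": "function"}]) is True
    assert gate_parallel_tool_calls(None, [{"type": "function"}]) is None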
@@ -480,23 +532,24 @@
         if not chunk.usage:
             return

+        self.accumulated_builtin_output_tokens += chunk.usage.completion_tokens
+
         if self.accumulated_usage is None:
             # Convert from chat completion format to response format
             self.accumulated_usage = OpenAIResponseUsage(
                 input_tokens=chunk.usage.prompt_tokens,
                 output_tokens=chunk.usage.completion_tokens,
                 total_tokens=chunk.usage.total_tokens,
-                input_tokens_details=(
-
-                    if chunk.usage.prompt_tokens_details
-                    else
+                input_tokens_details=OpenAIResponseUsageInputTokensDetails(
+                    cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens
+                    if chunk.usage.prompt_tokens_details and chunk.usage.prompt_tokens_details.cached_tokens is not None
+                    else 0
                 ),
-                output_tokens_details=(
-
-                    reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
-                )
+                output_tokens_details=OpenAIResponseUsageOutputTokensDetails(
+                    reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
                     if chunk.usage.completion_tokens_details
-
+                    and chunk.usage.completion_tokens_details.reasoning_tokens is not None
+                    else 0
                 ),
             )
         else:
@@ -506,17 +559,16 @@
                 output_tokens=self.accumulated_usage.output_tokens + chunk.usage.completion_tokens,
                 total_tokens=self.accumulated_usage.total_tokens + chunk.usage.total_tokens,
                 # Use latest non-null details
-                input_tokens_details=(
-
-                    if chunk.usage.prompt_tokens_details
-                    else self.accumulated_usage.input_tokens_details
+                input_tokens_details=OpenAIResponseUsageInputTokensDetails(
+                    cached_tokens=chunk.usage.prompt_tokens_details.cached_tokens
+                    if chunk.usage.prompt_tokens_details and chunk.usage.prompt_tokens_details.cached_tokens is not None
+                    else self.accumulated_usage.input_tokens_details.cached_tokens
                 ),
-                output_tokens_details=(
-
-                    reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
-                )
+                output_tokens_details=OpenAIResponseUsageOutputTokensDetails(
+                    reasoning_tokens=chunk.usage.completion_tokens_details.reasoning_tokens
                     if chunk.usage.completion_tokens_details
-
+                    and chunk.usage.completion_tokens_details.reasoning_tokens is not None
+                    else self.accumulated_usage.output_tokens_details.reasoning_tokens
                 ),
             )

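Both branches above now always build the token-detail objects, treating details a provider omits as 0 (or as the previously accumulated value) instead of leaving the field unset. A standalone sketch of the fallback, using a stand-in dataclass rather than the real usage types:

    from dataclasses import dataclass

    @dataclass
    class PromptTokensDetails:  # stand-in for the provider's prompt_tokens_details
        cached_tokens: int | None = None

    def cached_tokens(details: PromptTokensDetails | None, previous: int = 0) -> int:
        # Mirror the hunk: use the reported value when present, otherwise fall back
        if details is not None and details.cached_tokens is not None:
            return details.cached_tokens
        return previous

    assert cached_tokens(None) == 0
    assert cached_tokens(PromptTokensDetails(cached_tokens=42)) == 42
    assert cached_tokens(PromptTokensDetails(), previous=7) == 7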
@@ -652,7 +704,7 @@
         chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
         chunk_created = 0
         chunk_model = ""
-        chunk_finish_reason = ""
+        chunk_finish_reason: OpenAIFinishReason = "stop"
         chat_response_logprobs = []

         # Create a placeholder message item for delta events
@@ -744,9 +796,9 @@
                     chunk_finish_reason = chunk_choice.finish_reason

                 # Handle reasoning content if present (non-standard field for o1/o3 models)
-                if hasattr(chunk_choice.delta, "
+                if hasattr(chunk_choice.delta, "reasoning") and chunk_choice.delta.reasoning:
                     async for event in self._handle_reasoning_content_chunk(
-                        reasoning_content=chunk_choice.delta.
+                        reasoning_content=chunk_choice.delta.reasoning,
                         reasoning_part_emitted=reasoning_part_emitted,
                         reasoning_content_index=reasoning_content_index,
                         message_item_id=message_item_id,
@@ -758,7 +810,7 @@
                         else:
                             yield event
                     reasoning_part_emitted = True
-                    reasoning_text_accumulated.append(chunk_choice.delta.
+                    reasoning_text_accumulated.append(chunk_choice.delta.reasoning)

                 # Handle refusal content if present
                 if chunk_choice.delta.refusal:
@@ -1175,6 +1227,9 @@
         """Process an MCP tool configuration and emit appropriate streaming events."""
         from llama_stack.providers.utils.tools.mcp import list_mcp_tools

+        # Resolve connector_id to server_url if provided
+        mcp_tool = await resolve_mcp_connector_id(mcp_tool, self.connectors_api)
+
         # Emit mcp_list_tools.in_progress
         self.sequence_number += 1
         yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
@@ -1489,3 +1544,25 @@
         tools=tool_choice,
         mode="required",
     )
+
+
+async def resolve_mcp_connector_id(
+    mcp_tool: OpenAIResponseInputToolMCP,
+    connectors_api: Connectors,
+) -> OpenAIResponseInputToolMCP:
+    """Resolve connector_id to server_url for an MCP tool.
+
+    If the mcp_tool has a connector_id but no server_url, this function
+    looks up the connector and populates the server_url from it.
+
+    Args:
+        mcp_tool: The MCP tool configuration to resolve
+        connectors_api: The connectors API for looking up connectors
+
+    Returns:
+        The mcp_tool with server_url populated (may be same instance if already set)
+    """
+    if mcp_tool.connector_id and not mcp_tool.server_url:
+        connector = await connectors_api.get_connector(mcp_tool.connector_id)
+        return mcp_tool.model_copy(update={"server_url": connector.url})
+    return mcp_tool
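`resolve_mcp_connector_id` fills in `server_url` from the Connectors API only when a tool specifies `connector_id` without a URL. A self-contained sketch of the same rule with stand-in types (the real code uses OpenAIResponseInputToolMCP and the Connectors API shown above):

    import asyncio
    from dataclasses import dataclass, replace

    @dataclass
    class MCPToolSketch:  # stand-in for OpenAIResponseInputToolMCP
        connector_id: str | None = None
        server_url: str | None = None

    @dataclass
    class ConnectorSketch:  # stand-in for the connector record returned by the API
        url: str

    async def get_connector(connector_id: str) -> ConnectorSketch:
        # Fake lookup; the real call is connectors_api.get_connector(connector_id)
        return ConnectorSketch(url=f"https://mcp.example.invalid/{connector_id}")

    async def resolve(tool: MCPToolSketch) -> MCPToolSketch:
        # Same rule as above: only fill server_url when it is missing
        if tool.connector_id and not tool.server_url:
            connector = await get_connector(tool.connector_id)
            return replace(tool, server_url=connector.url)
        return tool

    print(asyncio.run(resolve(MCPToolSketch(connector_id="github"))))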
@@ -12,6 +12,7 @@ from pydantic import BaseModel

 from llama_stack_api import (
     OpenAIChatCompletionToolCall,
+    OpenAIFinishReason,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     OpenAIResponseInput,
@@ -52,7 +53,7 @@ class ChatCompletionResult:
     tool_calls: dict[int, OpenAIChatCompletionToolCall]
     created: int
     model: str
-    finish_reason:
+    finish_reason: OpenAIFinishReason
     message_item_id: str  # For streaming events
     tool_call_item_ids: dict[int, str]  # For streaming events
     content_part_emitted: bool  # Tracking state
@@ -53,6 +53,7 @@ from llama_stack_api import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
     ResponseGuardrailSpec,
+    RunModerationRequest,
     Safety,
 )

@@ -468,7 +469,9 @@ async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids
         else:
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")

-    guardrail_tasks = [
+    guardrail_tasks = [
+        safety_api.run_moderation(RunModerationRequest(input=messages, model=model_id)) for model_id in model_ids
+    ]
     responses = await asyncio.gather(*guardrail_tasks)

     for response in responses:
@@ -7,7 +7,7 @@
 import asyncio

 from llama_stack.log import get_logger
-from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
+from llama_stack_api import OpenAIMessageParam, RunShieldRequest, Safety, SafetyViolation, ViolationLevel

 log = get_logger(name=__name__, category="agents::meta_reference")

@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
     async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
         responses = await asyncio.gather(
             *[
-                self.safety_api.run_shield(shield_id=identifier, messages=messages
+                self.safety_api.run_shield(RunShieldRequest(shield_id=identifier, messages=messages))
                 for identifier in identifiers
             ]
         )
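This hunk is part of a broader 0.5.0 pattern: provider-facing APIs now take a single typed request model (RunShieldRequest here, with RunModerationRequest, GetModelRequest, and the eval requests elsewhere in this diff) instead of loose keyword arguments. A hedged sketch of the calling convention using a stand-in request class, not the real Pydantic model:

    import asyncio
    from dataclasses import dataclass, field

    @dataclass
    class RunShieldRequestSketch:  # stand-in; the real RunShieldRequest lives in llama_stack_api
        shield_id: str
        messages: list = field(default_factory=list)

    async def run_shield(request: RunShieldRequestSketch) -> None:
        # Implementations unpack fields from the request object instead of separate kwargs
        print(f"running shield {request.shield_id} on {len(request.messages)} messages")

    asyncio.run(run_shield(RunShieldRequestSketch(shield_id="llama-guard", messages=["hi"])))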
@@ -23,6 +23,7 @@ from llama_stack_api import (
     BatchObject,
     ConflictError,
     Files,
+    GetModelRequest,
     Inference,
     ListBatchesResponse,
     Models,
@@ -485,7 +486,7 @@ class ReferenceBatchesImpl(Batches):

         if "model" in request_body and isinstance(request_body["model"], str):
             try:
-                await self.models_api.get_model(request_body["model"])
+                await self.models_api.get_model(GetModelRequest(model_id=request_body["model"]))
             except Exception:
                 errors.append(
                     BatchError(
@@ -13,19 +13,25 @@ from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack_api import (
     Agents,
     Benchmark,
-    BenchmarkConfig,
     BenchmarksProtocolPrivate,
     DatasetIO,
     Datasets,
     Eval,
     EvaluateResponse,
+    EvaluateRowsRequest,
     Inference,
+    IterRowsRequest,
     Job,
+    JobCancelRequest,
+    JobResultRequest,
     JobStatus,
+    JobStatusRequest,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAICompletionRequestWithExtraBody,
     OpenAISystemMessageParam,
     OpenAIUserMessageParam,
+    RunEvalRequest,
+    ScoreRequest,
     Scoring,
 )

@@ -90,10 +96,9 @@ class MetaReferenceEvalImpl(

     async def run_eval(
         self,
-
-        benchmark_config: BenchmarkConfig,
+        request: RunEvalRequest,
     ) -> Job:
-        task_def = self.benchmarks[benchmark_id]
+        task_def = self.benchmarks[request.benchmark_id]
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions

@@ -101,15 +106,18 @@ class MetaReferenceEvalImpl(
         # dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)

         all_rows = await self.datasetio_api.iterrows(
-
-
+            IterRowsRequest(
+                dataset_id=dataset_id,
+                limit=(-1 if request.benchmark_config.num_examples is None else request.benchmark_config.num_examples),
+            )
         )
-
-            benchmark_id=benchmark_id,
+        eval_rows_request = EvaluateRowsRequest(
+            benchmark_id=request.benchmark_id,
             input_rows=all_rows.data,
             scoring_functions=scoring_functions,
-            benchmark_config=benchmark_config,
+            benchmark_config=request.benchmark_config,
         )
+        res = await self.evaluate_rows(eval_rows_request)

         # TODO: currently needs to wait for generation before returning
         # need job scheduler queue (ray/celery) w/ jobs api
@@ -118,9 +126,9 @@ class MetaReferenceEvalImpl(
         return Job(job_id=job_id, status=JobStatus.completed)

     async def _run_model_generation(
-        self, input_rows: list[dict[str, Any]],
+        self, input_rows: list[dict[str, Any]], request: EvaluateRowsRequest
     ) -> list[dict[str, Any]]:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
         sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}

@@ -165,50 +173,50 @@ class MetaReferenceEvalImpl(

     async def evaluate_rows(
         self,
-
-        input_rows: list[dict[str, Any]],
-        scoring_functions: list[str],
-        benchmark_config: BenchmarkConfig,
+        request: EvaluateRowsRequest,
     ) -> EvaluateResponse:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         # Agent evaluation removed
         if candidate.type == "model":
-            generations = await self._run_model_generation(input_rows,
+            generations = await self._run_model_generation(request.input_rows, request)
         else:
             raise ValueError(f"Invalid candidate type: {candidate.type}")

         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(request.input_rows, generations, strict=False)
         ]

-        if benchmark_config.scoring_params is not None:
+        if request.benchmark_config.scoring_params is not None:
             scoring_functions_dict = {
-                scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
-                for scoring_fn_id in scoring_functions
+                scoring_fn_id: request.benchmark_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in request.scoring_functions
             }
         else:
-            scoring_functions_dict = dict.fromkeys(scoring_functions)
+            scoring_functions_dict = dict.fromkeys(request.scoring_functions)

-
-            input_rows=score_input_rows,
+        score_request = ScoreRequest(
+            input_rows=score_input_rows,
+            scoring_functions=scoring_functions_dict,
         )
+        score_response = await self.scoring_api.score(score_request)

         return EvaluateResponse(generations=generations, scores=score_response.results)

-    async def job_status(self,
-        if job_id in self.jobs:
-            return Job(job_id=job_id, status=JobStatus.completed)
+    async def job_status(self, request: JobStatusRequest) -> Job:
+        if request.job_id in self.jobs:
+            return Job(job_id=request.job_id, status=JobStatus.completed)

-        raise ValueError(f"Job {job_id} not found")
+        raise ValueError(f"Job {request.job_id} not found")

-    async def job_cancel(self,
+    async def job_cancel(self, request: JobCancelRequest) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")

-    async def job_result(self,
-
+    async def job_result(self, request: JobResultRequest) -> EvaluateResponse:
+        job_status_request = JobStatusRequest(benchmark_id=request.benchmark_id, job_id=request.job_id)
+        job = await self.job_status(job_status_request)
         status = job.status
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")

-        return self.jobs[job_id]
+        return self.jobs[request.job_id]
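In the refactored eval flow above, run_eval unpacks RunEvalRequest, pulls rows with an IterRowsRequest, and delegates to evaluate_rows and the scoring API via EvaluateRowsRequest and ScoreRequest. One detail worth calling out is the row limit: when no sample cap is configured, -1 asks datasetio for every row. A tiny sketch of that rule (illustrative helper only):

    def iterrows_limit(num_examples: int | None) -> int:
        # -1 means "no limit" to the datasetio iterrows API
        return -1 if num_examples is None else num_examples

    assert iterrows_limit(None) == -1
    assert iterrows_limit(25) == 25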
@@ -0,0 +1,9 @@
+#import <Foundation/Foundation.h>
+
+//! Project version number for LocalInference.
+FOUNDATION_EXPORT double LocalInferenceVersionNumber;
+
+//! Project version string for LocalInference.
+FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];
+
+// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>