llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
- llama_stack-0.5.0.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,542 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
import asyncio
|
|
8
|
-
import time
|
|
9
|
-
import uuid
|
|
10
|
-
from collections.abc import AsyncIterator
|
|
11
|
-
|
|
12
|
-
from llama_stack.log import get_logger
|
|
13
|
-
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
|
|
14
|
-
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
|
15
|
-
from llama_stack.models.llama.llama3.prompt_templates import (
|
|
16
|
-
JsonCustomToolGenerator,
|
|
17
|
-
SystemDefaultGenerator,
|
|
18
|
-
)
|
|
19
|
-
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
|
20
|
-
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
|
21
|
-
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
|
22
|
-
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
|
23
|
-
)
|
|
24
|
-
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
|
25
|
-
from llama_stack.models.llama.sku_list import resolve_model
|
|
26
|
-
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
|
27
|
-
from llama_stack.providers.utils.inference.embedding_mixin import (
|
|
28
|
-
SentenceTransformerEmbeddingMixin,
|
|
29
|
-
)
|
|
30
|
-
from llama_stack.providers.utils.inference.model_registry import (
|
|
31
|
-
ModelRegistryHelper,
|
|
32
|
-
build_hf_repo_model_entry,
|
|
33
|
-
)
|
|
34
|
-
from llama_stack_api import (
|
|
35
|
-
InferenceProvider,
|
|
36
|
-
Model,
|
|
37
|
-
ModelsProtocolPrivate,
|
|
38
|
-
ModelType,
|
|
39
|
-
OpenAIAssistantMessageParam,
|
|
40
|
-
OpenAIChatCompletion,
|
|
41
|
-
OpenAIChatCompletionChunk,
|
|
42
|
-
OpenAIChatCompletionRequestWithExtraBody,
|
|
43
|
-
OpenAIChatCompletionUsage,
|
|
44
|
-
OpenAIChoice,
|
|
45
|
-
OpenAICompletion,
|
|
46
|
-
OpenAICompletionRequestWithExtraBody,
|
|
47
|
-
OpenAIUserMessageParam,
|
|
48
|
-
ToolChoice,
|
|
49
|
-
)
|
|
50
|
-
|
|
51
|
-
from .config import MetaReferenceInferenceConfig
|
|
52
|
-
from .generators import LlamaGenerator
|
|
53
|
-
from .model_parallel import LlamaModelParallelGenerator
|
|
54
|
-
|
|
55
|
-
log = get_logger(__name__, category="inference")

# Only one model-parallel worker process serves the loaded model, and it cannot
# handle overlapping requests; this semaphore admits a single request at a time.
SEMAPHORE = asyncio.Semaphore(value=1)
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
    """Convert one OpenAI-format tool into a Llama ``ToolDefinition``.

    OpenAI tools look like ``{"type": "function", "function": {"name": ...,
    "description": ..., "parameters": {...}}}``. Both plain dicts and
    attribute-style objects with the same shape are accepted.

    Args:
        tool: A single entry of the request's ``tools`` list.

    Returns:
        A ToolDefinition with an empty description / parameter schema when
        those fields are absent or falsy.

    Raises:
        ValueError: If the tool has no ``function`` entry.
    """

    def _field(obj, name):
        # Read ``name`` from a dict key or an attribute, whichever form ``obj`` uses.
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    function = _field(tool, "function")
    if function is None:
        raise ValueError(f"OpenAI tool definition is missing its 'function' entry: {tool!r}")

    return ToolDefinition(
        tool_name=_field(function, "name"),
        description=_field(function, "description") or "",
        parameters=_field(function, "parameters") or {},
    )
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def _get_tool_choice_prompt(tool_choice, tools) -> str:
    """Return the system-prompt sentence that enforces ``tool_choice``.

    ``tool_choice`` may be ``None``, a ``ToolChoice`` enum member, one of the
    strings ``"auto"`` / ``"required"`` / ``"none"``, a bare tool name, or an
    OpenAI-style selector such as
    ``{"type": "function", "function": {"name": "get_weather"}}``
    (dict or attribute-style object).

    Args:
        tool_choice: The requested tool-choice mode.
        tools: The tools offered in the request. Currently unused; kept for
            signature compatibility.

    Returns:
        An instruction string, or ``""`` when no extra instruction is needed
        (no choice given, ``"auto"``, or ``"none"``).
    """
    from enum import Enum

    if not tool_choice:
        return ""

    # Unwrap enum members (e.g. ToolChoice.auto) to their string value so enum
    # and plain-string inputs share one code path.
    mode = tool_choice.value if isinstance(tool_choice, Enum) else tool_choice

    if mode in ("auto", "none"):
        return ""
    if mode == "required":
        return "You MUST use one of the provided functions/tools to answer the user query."

    # A specific tool was requested. OpenAI-style selectors nest the name under
    # "function"; interpolating the selector itself would put its repr in the prompt.
    if isinstance(mode, dict):
        function = mode.get("function")
        tool_name = function.get("name") if isinstance(function, dict) else None
    else:
        tool_name = getattr(getattr(mode, "function", None), "name", None)
    if not tool_name:
        # Bare tool-name strings (and unrecognised shapes) are used as-is.
        tool_name = mode
    return f"You MUST use the tool `{tool_name}` to answer the user query."
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def _raw_content_as_str(content) -> str:
    """Flatten message content into plain text for use in a system prompt.

    Strings pass through unchanged and a RawTextItem contributes its ``text``.
    Lists are flattened recursively and joined with newlines. Anything else
    (presumably media items such as images) becomes the placeholder ``"<media>"``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, RawTextItem):
        return content.text
    if isinstance(content, list):
        return "\n".join(map(_raw_content_as_str, content))
    return "<media>"
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def _augment_raw_messages_for_tools_llama_3_1(
    raw_messages: list[RawMessage],
    tools: list,
    tool_choice,
) -> list[RawMessage]:
    """Prepend a Llama 3.1-style system message describing the available tools.

    The new system message is built from these sections, separated by newlines:
    the JSON custom-tool prompt (only when ``tools`` is non-empty), the default
    system prompt, the text of any original leading system message, and the
    tool-choice instruction (when one applies). The original leading system
    message, if any, is replaced by this merged message. The input list is not
    mutated.
    """
    has_system = bool(raw_messages) and raw_messages[0].role == "system"
    original_system = raw_messages[0] if has_system else None
    remaining = raw_messages[1:] if has_system else list(raw_messages)

    sections = []
    if tools:
        # OpenAI-format tools all carry string names, so they are all custom tools.
        definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
        sections.append(JsonCustomToolGenerator().gen(definitions).render())

    sections.append(SystemDefaultGenerator().gen().render())

    if original_system:
        sections.append(_raw_content_as_str(original_system.content))

    choice_text = _get_tool_choice_prompt(tool_choice, tools)
    if choice_text:
        sections.append(choice_text)

    merged_text = "\n".join(sections).strip()
    merged_system = RawMessage(role="system", content=[RawTextItem(text=merged_text)])
    return [merged_system] + remaining
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
def _augment_raw_messages_for_tools_llama_4(
    raw_messages: list[RawMessage],
    tools: list,
    tool_choice,
) -> list[RawMessage]:
    """Prepend a Llama 4 / 3.2 / 3.3-style (python_list) system message for tools.

    With tools, the Llama 4 python-list tool prompt is rendered with any
    original leading system message's text passed in as its system prompt.
    Without tools, the original system text is reused. A tool-choice
    instruction is appended on a new line when one applies. When the resulting
    text is empty, the messages are returned without any system message. The
    input list is not mutated.
    """
    has_system = bool(raw_messages) and raw_messages[0].role == "system"
    original_system = raw_messages[0] if has_system else None
    remaining = raw_messages[1:] if has_system else list(raw_messages)

    if tools:
        definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
        generator = PythonListCustomToolGeneratorLlama4()
        user_system_prompt = _raw_content_as_str(original_system.content) if original_system else None
        system_text = generator.gen(definitions, user_system_prompt).render()
    elif original_system:
        system_text = _raw_content_as_str(original_system.content)
    else:
        system_text = ""

    choice_text = _get_tool_choice_prompt(tool_choice, tools)
    if choice_text:
        system_text += "\n" + choice_text

    if not system_text:
        return remaining

    merged_system = RawMessage(role="system", content=[RawTextItem(text=system_text.strip())])
    return [merged_system] + remaining
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def augment_raw_messages_for_tools(
    raw_messages: list[RawMessage],
    params: OpenAIChatCompletionRequestWithExtraBody,
    llama_model,
) -> list[RawMessage]:
    """Inject tool definitions into the system prompt in the model family's format.

    The JSON (Llama 3.1) format is used for Llama 3.1 and multimodal Llama 3.2.
    The python_list format is used for text-only Llama 3.2, Llama 3.3 and Llama 4.
    Any other family falls back to the JSON format. When the request has no
    tools, ``raw_messages`` is returned unchanged (same object).
    """
    tools = params.tools
    if not tools:
        return raw_messages

    family = llama_model.model_family
    uses_json_format = family == ModelFamily.llama3_1 or (
        family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
    )
    uses_python_list_format = not uses_json_format and family in (
        ModelFamily.llama3_2,
        ModelFamily.llama3_3,
        ModelFamily.llama4,
    )

    augment = (
        _augment_raw_messages_for_tools_llama_4
        if uses_python_list_format
        else _augment_raw_messages_for_tools_llama_3_1
    )
    return augment(raw_messages, tools, params.tool_choice)
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
    """Construct an in-process LlamaGenerator for ``model_id``.

    load_model() calls this directly in the single-process path, and passes it
    as ``builder_fn`` to LlamaModelParallelGenerator in the distributed path.

    NOTE(review): load_model() forwards the object returned by resolve_model(),
    which may not be the llama_stack_api ``Model`` named in the annotation;
    confirm the intended type.
    """
    generator = LlamaGenerator(config, model_id, llama_model)
    return generator
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
class MetaReferenceInferenceImpl(
|
|
230
|
-
SentenceTransformerEmbeddingMixin,
|
|
231
|
-
InferenceProvider,
|
|
232
|
-
ModelsProtocolPrivate,
|
|
233
|
-
):
|
|
234
|
-
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
    """Store the provider config; no model is loaded until register_model() runs.

    Args:
        config: Provider configuration, read later by load_model() and shutdown().
    """
    self.config = config
    # Both are set by load_model(); check_model() treats None as "nothing loaded yet".
    self.model_id = None
    self.llama_model = None
    # Populated by load_model(); None until a model has been loaded, so the
    # attribute always exists.
    self.generator = None
|
|
238
|
-
|
|
239
|
-
async def initialize(self) -> None:
    """Nothing to do at startup; model weights are loaded by register_model() via load_model()."""
|
|
241
|
-
|
|
242
|
-
async def shutdown(self) -> None:
    """Stop the model-parallel worker group, if one was started.

    Only the distributed path (``config.create_distributed_process_group``)
    starts a generator in load_model(); the single-process generator is not
    stopped here.
    """
    # shutdown() can run before any model was registered/loaded, in which case
    # there is no generator to stop.
    generator = getattr(self, "generator", None)
    if self.config.create_distributed_process_group and generator is not None:
        generator.stop()
|
|
245
|
-
|
|
246
|
-
async def openai_completion(
    self,
    params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
    """Reject legacy (non-chat) text completions, which this provider does not implement.

    Raises:
        NotImplementedError: Always.
    """
    message = "OpenAI completion not supported by meta reference provider"
    raise NotImplementedError(message)
|
|
251
|
-
|
|
252
|
-
async def should_refresh_models(self) -> bool:
    """Report that the model list never needs periodic refreshing; models come only from register_model()."""
    needs_refresh = False
    return needs_refresh
|
|
254
|
-
|
|
255
|
-
async def list_models(self) -> list[Model] | None:
    """Return ``None``: this provider does not enumerate models on its own."""
    no_listing = None
    return no_listing
|
|
257
|
-
|
|
258
|
-
async def unregister_model(self, model_id: str) -> None:
    """Accept and ignore the request; the loaded generator and model state are left untouched."""
|
|
260
|
-
|
|
261
|
-
async def register_model(self, model: Model) -> Model:
    """Register ``model`` with this provider and, unless skipped, load its weights.

    The Llama SKU is resolved from ``model.metadata["llama_model"]`` when that
    key is present, otherwise from ``model.identifier``. A new
    ModelRegistryHelper holding only that SKU's HF-repo entry replaces any
    previous helper and performs the registration. Embedding models also get
    a sentence-transformer loaded. Finally, unless ``metadata["skip_load"]``
    is truthy, load_model() builds and warms up the generator.

    Args:
        model: The model resource being registered.

    Returns:
        The model as returned by ModelRegistryHelper.register_model().

    Raises:
        ValueError: If no Llama SKU matches the metadata value or identifier.
    """
    sku_name = model.metadata["llama_model"] if "llama_model" in model.metadata else model.identifier
    llama_model = resolve_model(sku_name)
    if llama_model is None:
        raise ValueError(
            "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
        )

    hf_entry = build_hf_repo_model_entry(
        llama_model.descriptor(),
        llama_model.core_model_id.value,
    )
    self.model_registry_helper = ModelRegistryHelper([hf_entry])
    registered = await self.model_registry_helper.register_model(model)

    if registered.model_type == ModelType.embedding:
        # NOTE(review): if the mixin's _load_sentence_transformer_model is a
        # coroutine function, this call needs ``await``; confirm.
        self._load_sentence_transformer_model(registered.provider_resource_id)

    # TODO: skipping the load via model metadata is a workaround; replace it with
    # an explicit option.
    if registered.metadata.get("skip_load"):
        return registered

    await self.load_model(registered.identifier, llama_model)
    return registered
|
|
292
|
-
|
|
293
|
-
async def load_model(self, model_id, llama_model) -> None:
    """Build the generator for ``llama_model``, make it the active model, and warm it up.

    With ``config.create_distributed_process_group`` set, a
    LlamaModelParallelGenerator is created and started. Its size is
    ``config.model_parallel_size``, falling back to the SKU's
    ``pth_file_count``. Otherwise llama_builder_fn() builds an in-process
    generator.

    Args:
        model_id: Identifier that requests must use to target this model
            (enforced by check_model()).
        llama_model: The Llama SKU; register_model() passes the result of
            resolve_model().
    """
    log.info(f"Loading model `{model_id}`")

    builder_params = [self.config, model_id, llama_model]

    if not self.config.create_distributed_process_group:
        self.generator = llama_builder_fn(*builder_params)
    else:
        parallel_size = self.config.model_parallel_size or llama_model.pth_file_count
        # Llama 4 has its own tokenizer/chat format; every other family uses Llama 3's.
        if llama_model.model_family == ModelFamily.llama4:
            chat_format = Llama4ChatFormat(Llama4Tokenizer.get_instance())
        else:
            chat_format = Llama3ChatFormat(Llama3Tokenizer.get_instance())
        self.generator = LlamaModelParallelGenerator(
            model_parallel_size=parallel_size,
            builder_fn=llama_builder_fn,
            builder_params=builder_params,
            formatter=chat_format,
        )
        self.generator.start()

    # Must be set before the warm-up call: openai_chat_completion() begins with
    # check_model(), which rejects requests while these are None.
    self.model_id = model_id
    self.llama_model = llama_model

    log.info("Warming up...")
    warmup_params = OpenAIChatCompletionRequestWithExtraBody(
        model=model_id,
        messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
        max_tokens=20,
    )
    await self.openai_chat_completion(params=warmup_params)
    log.info("Warmed up!")
|
|
326
|
-
|
|
327
|
-
def check_model(self, request) -> None:
    """Verify that a model is loaded and that ``request`` targets it.

    Args:
        request: Any object exposing a ``model`` attribute.

    Raises:
        RuntimeError: if ``self.model_id`` or ``self.llama_model`` is still
            ``None``, or if ``request.model`` differs from ``self.model_id``.
    """
    has_model = self.model_id is not None and self.llama_model is not None
    if not has_model:
        raise RuntimeError(
            "No available model yet, please register your requested model or add your model in the resources first"
        )
    if request.model != self.model_id:
        raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
|
|
334
|
-
|
|
335
|
-
async def openai_chat_completion(
    self,
    params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
    """Serve an OpenAI-style chat completion request with the loaded Llama model.

    The request's messages are converted to raw Llama messages, augmented with
    tool definitions, and run through either the in-process ``LlamaGenerator``
    or the model-parallel process group.

    Args:
        params: The chat completion request. ``params.model`` must match the
            loaded model; ``params.stream`` selects streaming mode.

    Returns:
        When ``params.stream`` is truthy, the async iterator of
        ``OpenAIChatCompletionChunk`` produced by ``_stream_chat_completion``;
        otherwise a single ``OpenAIChatCompletion``.

    Raises:
        RuntimeError: via ``check_model`` if no model is loaded or the requested
            model does not match the loaded one.
    """
    self.check_model(params)

    # Convert OpenAI messages to RawMessages
    from llama_stack.models.llama.datatypes import StopReason
    from llama_stack.providers.utils.inference.prompt_adapter import (
        convert_openai_message_to_raw_message,
        decode_assistant_message,
    )

    raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]

    # Augment messages with tool definitions if tools are present
    raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)

    # Call generator's chat_completion method (works for both single-GPU and model-parallel)
    if isinstance(self.generator, LlamaGenerator):
        generator = self.generator.chat_completion(params, raw_messages)
    else:
        # Model parallel: submit task to process group
        generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))

    if params.stream:
        return self._stream_chat_completion(generator, params)

    # Non-streaming: keep only non-ignored "output" tokens and join them once
    # (avoids building the string with repeated `+=`).
    # NOTE(review): this is a plain synchronous loop inside a coroutine; if the
    # generator blocks while producing tokens it also blocks the event loop — confirm.
    generated_text = "".join(
        result.text
        for result_batch in generator
        for result in result_batch
        if not result.ignore_token and result.source == "output"
    )

    # Decode the assistant message to extract tool calls. The stop reason is fixed
    # to end_of_turn, so finish_reason below is only ever "tool_calls" or "stop";
    # truncation (e.g. by max_tokens) is not reported as "length".
    decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)

    # Convert tool calls to OpenAI format
    openai_tool_calls = None
    if decoded_message.tool_calls:
        from llama_stack_api import (
            OpenAIChatCompletionToolCall,
            OpenAIChatCompletionToolCallFunction,
        )

        openai_tool_calls = [
            OpenAIChatCompletionToolCall(
                # The model does not produce call ids, so a random one is synthesized here.
                id=f"call_{uuid.uuid4().hex[:24]}",
                type="function",
                function=OpenAIChatCompletionToolCallFunction(
                    name=str(tc.tool_name),
                    arguments=tc.arguments,
                ),
            )
            for tc in decoded_message.tool_calls
        ]

    finish_reason = "tool_calls" if openai_tool_calls else "stop"

    # Flatten decoded content to plain text; non-text items in a list are dropped.
    content = ""
    if isinstance(decoded_message.content, str):
        content = decoded_message.content
    elif isinstance(decoded_message.content, list):
        content = "".join(item.text for item in decoded_message.content if isinstance(item, RawTextItem))

    # The completion id is synthesized locally from a random UUID.
    response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    created = int(time.time())

    return OpenAIChatCompletion(
        id=response_id,
        object="chat.completion",
        created=created,
        model=params.model,
        choices=[
            OpenAIChoice(
                index=0,
                message=OpenAIAssistantMessageParam(
                    role="assistant",
                    content=content,
                    tool_calls=openai_tool_calls,
                ),
                finish_reason=finish_reason,
                logprobs=None,
            )
        ],
        usage=OpenAIChatCompletionUsage(
            prompt_tokens=0,  # TODO: calculate properly
            completion_tokens=0,  # TODO: calculate properly
            total_tokens=0,  # TODO: calculate properly
        ),
    )
|
|
436
|
-
|
|
437
|
-
async def _stream_chat_completion(
    self,
    generator,
    params: OpenAIChatCompletionRequestWithExtraBody,
) -> AsyncIterator[OpenAIChatCompletionChunk]:
    """Stream chat completion chunks as they're generated.

    Yields, in order:
      1. one chunk per non-ignored ``"output"`` token, carrying that token's text;
      2. if the fully decoded text contains tool calls, one chunk carrying them;
      3. a final empty-delta chunk carrying ``finish_reason``
         (``"tool_calls"`` or ``"stop"``).

    NOTE(review): every output token is streamed as content before decoding, so any
    raw tool-call text the model emits also reaches the client as content deltas.

    Args:
        generator: Iterable of token-result batches produced for this request.
        params: The originating request; only ``params.model`` is read here.
    """
    from llama_stack.models.llama.datatypes import StopReason
    from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
    from llama_stack_api import (
        OpenAIChatCompletionChunk,
        OpenAIChatCompletionToolCall,
        OpenAIChatCompletionToolCallFunction,
        OpenAIChoiceDelta,
        OpenAIChunkChoice,
    )

    response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    created = int(time.time())

    def make_chunk(delta: OpenAIChoiceDelta, finish_reason: str) -> OpenAIChatCompletionChunk:
        # All chunks of this stream share the same id, created timestamp and model.
        return OpenAIChatCompletionChunk(
            id=response_id,
            object="chat.completion.chunk",
            created=created,
            model=params.model,
            choices=[
                OpenAIChunkChoice(
                    index=0,
                    delta=delta,
                    finish_reason=finish_reason,
                    logprobs=None,
                )
            ],
        )

    # Collect pieces in a list and join once at the end, instead of repeated `+=`.
    text_parts: list[str] = []

    # Yield chunks as tokens are generated
    for result_batch in generator:
        for result in result_batch:
            if result.ignore_token or result.source != "output":
                continue
            text_parts.append(result.text)
            yield make_chunk(OpenAIChoiceDelta(role="assistant", content=result.text), "")

    # After generation completes, decode the full message to extract tool calls
    decoded_message = decode_assistant_message("".join(text_parts), StopReason.end_of_turn)

    if decoded_message.tool_calls:
        openai_tool_calls = [
            OpenAIChatCompletionToolCall(
                # The model does not produce call ids, so a random one is synthesized here.
                id=f"call_{uuid.uuid4().hex[:24]}",
                type="function",
                function=OpenAIChatCompletionToolCallFunction(
                    name=str(tc.tool_name),
                    arguments=tc.arguments,
                ),
            )
            for tc in decoded_message.tool_calls
        ]
        yield make_chunk(OpenAIChoiceDelta(role="assistant", tool_calls=openai_tool_calls), "")
        finish_reason = "tool_calls"
    else:
        finish_reason = "stop"

    # Final chunk: empty delta, carries only the finish_reason
    yield make_chunk(OpenAIChoiceDelta(), finish_reason)
|
|
@@ -1,77 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
from collections.abc import Callable
|
|
8
|
-
from functools import partial
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
|
12
|
-
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
|
13
|
-
|
|
14
|
-
from .parallel_utils import ModelParallelProcessGroup
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class ModelRunner:
|
|
18
|
-
def __init__(self, llama):
|
|
19
|
-
self.llama = llama
|
|
20
|
-
|
|
21
|
-
def __call__(self, task: Any):
|
|
22
|
-
task_type = task[0]
|
|
23
|
-
if task_type == "chat_completion":
|
|
24
|
-
# task[1] is [params, raw_messages]
|
|
25
|
-
params, raw_messages = task[1]
|
|
26
|
-
return self.llama.chat_completion(params, raw_messages)
|
|
27
|
-
else:
|
|
28
|
-
raise ValueError(f"Unexpected task type {task_type}")
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def init_model_cb(
    builder_fn: Callable,
    params: list[Any],
):
    """Construct the model via ``builder_fn(*params)`` and wrap it in a ``ModelRunner``.

    ``LlamaModelParallelGenerator`` binds ``builder_fn``/``params`` with
    ``functools.partial`` and passes the result to ``ModelParallelProcessGroup``
    as its ``init_model_cb``.
    """
    return ModelRunner(builder_fn(*params))
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class LlamaModelParallelGenerator:
    """
    This abstraction exists so that
    - we can run model parallel code without needing to run the CLIs via torchrun, and
    - model parallel code can also be used within a notebook context.

    A context manager is used to ensure that the model parallel process group is
    started and stopped correctly. This makes the ergonomics a little awkward,
    because it isn't immediately clear at the call site why a context manager is
    needed; ``start()`` and ``stop()`` are explicit equivalents of entering and
    exiting the context.
    """

    def __init__(
        self,
        model_parallel_size: int,
        builder_fn: Callable,
        builder_params: list[Any],
        formatter: Llama3ChatFormat | Llama4ChatFormat,
    ):
        self.model_parallel_size = model_parallel_size
        self.builder_fn = builder_fn
        self.builder_params = builder_params
        self.formatter = formatter
        # The ModelParallelProcessGroup, created by __enter__/start(); None while not running.
        # Initialised here so stop()/__exit__ before start() no longer raise AttributeError.
        self.group = None

    def start(self):
        """Start the process group (equivalent to entering the context)."""
        self.__enter__()

    def stop(self):
        """Stop the process group (equivalent to exiting the context); safe to call repeatedly."""
        self.__exit__(None, None, None)

    def __enter__(self):
        self.group = ModelParallelProcessGroup(
            self.model_parallel_size,
            init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
        )
        self.group.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # No group running (never started, or already stopped): nothing to tear down.
        if self.group is None:
            return
        self.group.stop()
        # Clear only after a successful stop so a failed stop can be retried.
        self.group = None
|