llama-stack 0.4.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +187 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/RECORD +135 -279
- llama_stack-0.5.0.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/remote/inference/vertexai/vertexai.py

@@ -40,9 +40,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
         Get the Vertex AI OpenAI-compatible API base URL.

         Returns the Vertex AI OpenAI-compatible endpoint URL.
-        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+        Source: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/start/openai
         """
-
+        if not self.config.location or self.config.location == "global":
+            return f"https://aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/global/endpoints/openapi"
+        else:
+            return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"

     async def list_provider_model_ids(self) -> Iterable[str]:
         """
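The hunk above makes the OpenAI-compatible base URL location-aware: a missing or "global" location keeps the global endpoint, anything else routes to the matching regional host. As a quick illustration only, the sketch below mirrors those two f-strings with made-up `project` and `location` values; it does not call the adapter itself.

```python
# Sketch: reproduces the URL selection added above, with hypothetical inputs.
def vertex_openai_base_url(project: str, location: str | None) -> str:
    if not location or location == "global":
        return f"https://aiplatform.googleapis.com/v1/projects/{project}/locations/global/endpoints/openapi"
    return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/endpoints/openapi"


print(vertex_openai_base_url("my-project", None))           # global endpoint
print(vertex_openai_base_url("my-project", "us-central1"))  # regional endpoint
```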
llama_stack/providers/remote/inference/vllm/config.py

@@ -4,11 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import warnings
 from pathlib import Path

-from pydantic import Field, HttpUrl, SecretStr,
+from pydantic import Field, HttpUrl, SecretStr, model_validator

-from llama_stack.providers.utils.inference.model_registry import
+from llama_stack.providers.utils.inference.model_registry import (
+    NetworkConfig,
+    RemoteInferenceProviderConfig,
+    TLSConfig,
+)
 from llama_stack_api import json_schema_type

@@ -27,23 +32,33 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         alias="api_token",
         description="The API token",
     )
-    tls_verify: bool | str = Field(
-        default=
-
+    tls_verify: bool | str | None = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED: Use 'network.tls.verify' instead. Whether to verify TLS certificates. "
+        "Can be a boolean or a path to a CA certificate file.",
     )

-    @
-
-
-    if
-
-
-
-
-
-
-
+    @model_validator(mode="after")
+    def migrate_tls_verify_to_network(self) -> "VLLMInferenceAdapterConfig":
+        """Migrate legacy tls_verify to network.tls.verify for backward compatibility."""
+        if self.tls_verify is not None:
+            warnings.warn(
+                "The 'tls_verify' config option is deprecated. Please use 'network.tls.verify' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            # Convert string path to Path if needed
+            if isinstance(self.tls_verify, str):
+                verify_value: bool | Path = Path(self.tls_verify)
+            else:
+                verify_value = self.tls_verify
+
+            if self.network is None:
+                self.network = NetworkConfig(tls=TLSConfig(verify=verify_value))
+            elif self.network.tls is None:
+                self.network.tls = TLSConfig(verify=verify_value)
+        return self

     @classmethod
     def sample_run_config(
@@ -55,5 +70,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
             "base_url": base_url,
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
-            "
+            "network": {
+                "tls": {
+                    "verify": "${env.VLLM_TLS_VERIFY:=true}",
+                },
+            },
         }
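Taken together, the config.py hunks above deprecate the flat `tls_verify` field in favor of a nested `network.tls.verify` setting, with an "after" model validator that copies a legacy value forward and emits a `DeprecationWarning` (the matching removal of the per-request httpx override appears in the vllm.py hunk below). The following is a self-contained sketch of that migration pattern using toy pydantic models, not the real `VLLMInferenceAdapterConfig`; the real class has additional fields and defaults that are not represented here.

```python
# Standalone sketch of the deprecation shim above (toy models, not the real config class).
import warnings
from pathlib import Path

from pydantic import BaseModel, Field, model_validator


class TLSConfig(BaseModel):
    verify: bool | Path = True


class NetworkConfig(BaseModel):
    tls: TLSConfig | None = None


class DemoConfig(BaseModel):
    network: NetworkConfig | None = None
    tls_verify: bool | str | None = Field(default=None, deprecated=True)

    @model_validator(mode="after")
    def migrate_tls_verify_to_network(self) -> "DemoConfig":
        # Mirror the migration: warn, coerce a string path to Path, then fill network.tls.
        if self.tls_verify is not None:
            warnings.warn("'tls_verify' is deprecated; use 'network.tls.verify'", DeprecationWarning, stacklevel=2)
            verify = Path(self.tls_verify) if isinstance(self.tls_verify, str) else self.tls_verify
            if self.network is None:
                self.network = NetworkConfig(tls=TLSConfig(verify=verify))
            elif self.network.tls is None:
                self.network.tls = TLSConfig(verify=verify)
        return self


cfg = DemoConfig(tls_verify="/etc/ssl/certs/my-ca.pem")
print(cfg.network.tls.verify)  # /etc/ssl/certs/my-ca.pem (coerced to a Path)
```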
llama_stack/providers/remote/inference/vllm/vllm.py

@@ -73,9 +73,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")

-    def get_extra_client_params(self):
-        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
-
     async def check_model_availability(self, model: str) -> bool:
         """
         Skip the check when running without authentication.
llama_stack/providers/remote/inference/watsonx/watsonx.py

@@ -23,6 +23,7 @@ from llama_stack_api import (
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
+    validate_embeddings_input_is_text,
 )

 logger = get_logger(name=__name__, category="providers::remote::watsonx")
@@ -147,6 +148,9 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         """
         Override parent method to add watsonx-specific parameters.
         """
+        # Validate that input contains only text, not token arrays
+        validate_embeddings_input_is_text(params)
+
         model_obj = await self.model_store.get_model(params.model)

         # Convert input to list if it's a string
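With the `validate_embeddings_input_is_text` guard above, the watsonx embeddings path now rejects pre-tokenized (token-array) input up front instead of passing it through. A hedged client-side sketch of the distinction follows; the base URL, API key, and model id are placeholders rather than values from this diff, and the exact exception surfaced for token-array input is provider-defined.

```python
# Sketch: text input embeds normally, token-id arrays are what the new guard rejects.
from openai import OpenAI

# Placeholder endpoint for a running stack's OpenAI-compatible routes (assumption).
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

ok = client.embeddings.create(model="your-embedding-model", input=["hello world"])
print(len(ok.data[0].embedding))

try:
    client.embeddings.create(model="your-embedding-model", input=[[101, 2023, 102]])
except Exception as exc:  # the specific error type/status is not specified by this diff
    print(f"rejected: {exc}")
```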
llama_stack/providers/remote/post_training/nvidia/README.md

@@ -0,0 +1,151 @@
# NVIDIA Post-Training Provider for LlamaStack

This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.

## Features

- Supervised fine-tuning of Llama models
- LoRA fine-tuning support
- Job management and status tracking

## Getting Started

### Prerequisites

- LlamaStack with NVIDIA configuration
- Access to Hosted NVIDIA NeMo Customizer service
- Dataset registered in the Hosted NVIDIA NeMo Customizer service
- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service

### Setup

Build the NVIDIA environment:

```bash
uv pip install llama-stack-client
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```

### Basic Usage using the LlamaStack Python Client

### Create Customization Job

#### Initialize the client

```python
import os

os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"

from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```

#### Configure fine-tuning parameters

```python
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
    TrainingConfig,
    TrainingConfigDataConfig,
    TrainingConfigOptimizerConfig,
)
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
```

#### Set up LoRA configuration

```python
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
```

#### Configure training data

```python
data_config = TrainingConfigDataConfig(
    dataset_id="your-dataset-id",  # Use client.datasets.list() to see available datasets
    batch_size=16,
)
```

#### Configure optimizer

```python
optimizer_config = TrainingConfigOptimizerConfig(
    lr=0.0001,
)
```

#### Set up training configuration

```python
training_config = TrainingConfig(
    n_epochs=2,
    data_config=data_config,
    optimizer_config=optimizer_config,
)
```

#### Start fine-tuning job

```python
training_job = client.post_training.supervised_fine_tune(
    job_uuid="unique-job-id",
    model="meta-llama/Llama-3.1-8B-Instruct",
    checkpoint_dir="",
    algorithm_config=algorithm_config,
    training_config=training_config,
    logger_config={},
    hyperparam_search_config={},
)
```

### List all jobs

```python
jobs = client.post_training.job.list()
```

### Check job status

```python
job_status = client.post_training.job.status(job_uuid="your-job-id")
```

### Cancel a job

```python
client.post_training.job.cancel(job_uuid="your-job-id")
```

### Inference with the fine-tuned model

#### 1. Register the model

```python
from llama_stack_api.models import Model, ModelType

client.models.register(
    model_id="test-example-model@v1",
    provider_id="nvidia",
    provider_model_id="test-example-model@v1",
    model_type=ModelType.llm,
)
```

#### 2. Inference with the fine-tuned model

```python
response = client.completions.create(
    prompt="Complete the sentence using one word: Roses are red, violets are ",
    stream=False,
    model="test-example-model@v1",
    max_tokens=50,
)
print(response.choices[0].text)
```
llama_stack/providers/remote/post_training/nvidia/models.py

@@ -5,23 +5,15 @@
 # the root directory of this source tree.


-from llama_stack.
-from llama_stack.providers.utils.inference.model_registry import (
-    ProviderModelEntry,
-    build_hf_repo_model_entry,
-)
+from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry

 _MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta/llama-3.1-8b-instruct",
-
+        "Llama3.1-8B-Instruct",
     ),
     build_hf_repo_model_entry(
         "meta/llama-3.2-1b-instruct",
-
+        "Llama3.2-1B-Instruct",
     ),
 ]
-
-
-def get_model_entries() -> list[ProviderModelEntry]:
-    return _MODEL_ENTRIES
llama_stack/providers/remote/post_training/nvidia/post_training.py

@@ -14,13 +14,15 @@ from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostT
 from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack_api import (
-
-
+    CancelTrainingJobRequest,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )

 from .models import _MODEL_ENTRIES
@@ -156,7 +158,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

         return ListNvidiaPostTrainingJobs(data=jobs)

-    async def get_training_job_status(
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> NvidiaPostTrainingJobStatusResponse:
         """Get the status of a customization job.
         Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.

@@ -178,8 +182,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         """
         response = await self._make_request(
             "GET",
-            f"/v1/customization/jobs/{job_uuid}/status",
-            params={"job_id": job_uuid},
+            f"/v1/customization/jobs/{request.job_uuid}/status",
+            params={"job_id": request.job_uuid},
         )

         api_status = response.pop("status").lower()
@@ -187,18 +191,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

         return NvidiaPostTrainingJobStatusResponse(
             status=JobStatus(mapped_status),
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             started_at=datetime.fromisoformat(response.pop("created_at")),
             updated_at=datetime.fromisoformat(response.pop("updated_at")),
             **response,
         )

-    async def cancel_training_job(self,
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
         await self._make_request(
-            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+            method="POST", path=f"/v1/customization/jobs/{request.job_uuid}/cancel", params={"job_id": request.job_uuid}
         )

-    async def get_training_job_artifacts(
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse:
         raise NotImplementedError("Job artifacts are not implemented yet")

     async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
@@ -206,13 +212,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

     async def supervised_fine_tune(
         self,
-
-        training_config: dict[str, Any],
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> NvidiaPostTrainingJob:
         """
         Fine-tunes a model on a dataset.
@@ -300,13 +300,16 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         User is informed about unsupported parameters via warnings.
         """

+        # Convert training_config to dict for internal processing
+        training_config = request.training_config.model_dump()
+
         # Check for unsupported method parameters
         unsupported_method_params = []
-        if checkpoint_dir:
-            unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
-        if hyperparam_search_config:
+        if request.checkpoint_dir:
+            unsupported_method_params.append(f"checkpoint_dir={request.checkpoint_dir}")
+        if request.hyperparam_search_config:
             unsupported_method_params.append("hyperparam_search_config")
-        if logger_config:
+        if request.logger_config:
             unsupported_method_params.append("logger_config")

         if unsupported_method_params:
@@ -344,7 +347,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

         # Prepare base job configuration
         job_config = {
-            "config": model,
+            "config": request.model,
             "dataset": {
                 "name": training_config["data_config"]["dataset_id"],
                 "namespace": self.config.dataset_namespace,
@@ -388,14 +391,14 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
             job_config["hyperparameters"].pop("sft")

         # Handle LoRA-specific configuration
-        if algorithm_config:
-            if algorithm_config.type == "LoRA":
-                warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
+        if request.algorithm_config:
+            if request.algorithm_config.type == "LoRA":
+                warn_unsupported_params(request.algorithm_config, supported_params["lora_config"], "LoRA config")
                 job_config["hyperparameters"]["lora"] = {
-                    k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
+                    k: v for k, v in {"alpha": request.algorithm_config.alpha}.items() if v is not None
                 }
             else:
-                raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
+                raise NotImplementedError(f"Unsupported algorithm config: {request.algorithm_config}")

         # Create the customization job
         response = await self._make_request(
@@ -416,12 +419,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):

     async def preference_optimize(
         self,
-
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
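The hunks above switch the adapter's methods from loose keyword arguments to typed request objects imported from `llama_stack_api` (`SupervisedFineTuneRequest`, `GetTrainingJobStatusRequest`, `CancelTrainingJobRequest`, and so on); the client-side flow in the README above is unchanged. A hedged sketch of driving the adapter directly with the new surface follows; `adapter` stands for an already-constructed `NvidiaPostTrainingAdapter`, only fields visible in the hunks are used, and the request constructors may accept or require more than is shown.

```python
# Sketch: the request-object call pattern implied by the hunks above (field names from the diff).
from llama_stack_api import CancelTrainingJobRequest, GetTrainingJobStatusRequest


async def check_and_maybe_cancel(adapter, job_uuid: str) -> None:
    # Poll the customization job's status via the typed request.
    status = await adapter.get_training_job_status(GetTrainingJobStatusRequest(job_uuid=job_uuid))
    print(status.status, status.started_at, status.updated_at)

    # Cancel with the matching request type; the status values checked here are assumptions.
    if status.status.value in ("scheduled", "in_progress"):
        await adapter.cancel_training_job(CancelTrainingJobRequest(job_uuid=job_uuid))
```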
llama_stack/providers/remote/safety/bedrock/bedrock.py

@@ -5,12 +5,13 @@
 # the root directory of this source tree.

 import json
-from typing import Any

 from llama_stack.log import get_logger
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-
+    GetShieldRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -24,7 +25,7 @@ from .config import BedrockSafetyConfig
 logger = get_logger(name=__name__, category="safety::bedrock")


-class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
+class BedrockSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         self.config = config
         self.registered_shields = []
@@ -55,49 +56,31 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
-
-        """
-        This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
-        ```content = [
-            {
-                "text": {
-                    "text": "Is the AB503 Product a better investment than the S&P 500?"
-                }
-            }
-        ]```
-        Incoming messages contain content, role . For now we will extract the content and
-        default the "qualifiers": ["query"]
-        """
+            raise ValueError(f"Shield {request.shield_id} not found")

         shield_params = shield.params
-        logger.debug(f"run_shield::{shield_params}::messages={messages}")
+        logger.debug(f"run_shield::{shield_params}::messages={request.messages}")

-        # - convert the messages into format Bedrock expects
         content_messages = []
-        for message in messages:
+        for message in request.messages:
             content_messages.append({"text": {"text": message.content}})
         logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")

         response = self.bedrock_runtime_client.apply_guardrail(
             guardrailIdentifier=shield.provider_resource_id,
             guardrailVersion=shield_params["guardrailVersion"],
-            source="OUTPUT",
+            source="OUTPUT",
             content=content_messages,
         )
         if response["action"] == "GUARDRAIL_INTERVENED":
             user_message = ""
             metadata = {}
             for output in response["outputs"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 user_message = output["text"]
             for assessment in response["assessments"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 metadata = dict(assessment)

         return RunShieldResponse(
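Both safety adapters touched in this release (Bedrock above, NVIDIA below) now take a single `RunShieldRequest`, resolve shields via `GetShieldRequest`, and inherit `run_moderation` from the new shared `ShieldToModerationMixin`. A hedged sketch of calling an adapter directly with the new request type follows; `adapter` is assumed to be an already-initialized safety adapter, and the dict-shaped message is an assumption (the diff only shows that each message exposes `.content`, while the client-facing call shown in the README below still passes `shield_id` plus `messages`).

```python
# Sketch: provider-level run_shield call with the request objects introduced above.
from llama_stack_api import RunShieldRequest


async def moderate(adapter, shield_id: str, user_text: str) -> None:
    request = RunShieldRequest(
        shield_id=shield_id,
        messages=[{"role": "user", "content": user_text}],  # message shape is an assumption
    )
    response = await adapter.run_shield(request)
    if response.violation:
        print(f"violation: {response.violation.user_message}")
    else:
        print("no safety violation")
```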
llama_stack/providers/remote/safety/nvidia/README.md

@@ -0,0 +1,78 @@
# NVIDIA Safety Provider for LlamaStack

This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.

## Features

- Run safety checks for messages

## Getting Started

### Prerequisites

- LlamaStack with NVIDIA configuration
- Access to NVIDIA NeMo Guardrails service
- NIM for model to use for safety check is deployed

### Setup

Build the NVIDIA environment:

```bash
uv pip install llama-stack-client
uv run llama stack list-deps nvidia | xargs -L1 uv pip install
```

### Basic Usage using the LlamaStack Python Client

#### Initialize the client

```python
import os

os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"

from llama_stack.core.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```

#### Create a safety shield

```python
from llama_stack_api.safety import Shield
from llama_stack_api.inference import Message

# Create a safety shield
shield = Shield(
    shield_id="your-shield-id",
    provider_resource_id="safety-model-id",  # The model to use for safety checks
    description="Safety checks for content moderation",
)

# Register the shield
await client.safety.register_shield(shield)
```

#### Run safety checks

```python
# Messages to check
messages = [Message(role="user", content="Your message to check")]

# Run safety check
response = await client.safety.run_shield(
    shield_id="your-shield-id",
    messages=messages,
)

# Check for violations
if response.violation:
    print(f"Safety violation detected: {response.violation.user_message}")
    print(f"Violation level: {response.violation.violation_level}")
    print(f"Metadata: {response.violation.metadata}")
else:
    print("No safety violations detected")
```
llama_stack/providers/remote/safety/nvidia/nvidia.py

@@ -9,9 +9,11 @@ from typing import Any
 import requests

 from llama_stack.log import get_logger
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -25,7 +27,7 @@ from .config import NVIDIASafetyConfig
 logger = get_logger(name=__name__, category="safety::nvidia")


-class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
+class NVIDIASafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
         """
         Initialize the NVIDIASafetyAdapter with a given safety configuration.
@@ -48,32 +50,14 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-
-
-        """
-        Run a safety shield check against the provided messages.
-
-        Args:
-            shield_id (str): The unique identifier for the shield to be used.
-            messages (List[Message]): A list of Message objects representing the conversation history.
-            params (Optional[dict[str, Any]]): Additional parameters for the shield check.
-
-        Returns:
-            RunShieldResponse: The response containing safety violation details if any.
-
-        Raises:
-            ValueError: If the shield with the provided shield_id is not found.
-        """
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        """Run a safety shield check against the provided messages."""
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")

         self.shield = NeMoGuardrails(self.config, shield.shield_id)
-        return await self.shield.run(messages)
-
-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
+        return await self.shield.run(request.messages)


 class NeMoGuardrails:
|