llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
llama_stack/providers/remote/eval/nvidia/eval.py

@@ -11,15 +11,19 @@ from llama_stack.providers.utils.inference.model_registry import ModelRegistryHe
 from llama_stack_api import (
     Agents,
     Benchmark,
-    BenchmarkConfig,
     BenchmarksProtocolPrivate,
     DatasetIO,
     Datasets,
     Eval,
     EvaluateResponse,
+    EvaluateRowsRequest,
     Inference,
     Job,
+    JobCancelRequest,
+    JobResultRequest,
     JobStatus,
+    JobStatusRequest,
+    RunEvalRequest,
     Scoring,
     ScoringResult,
 )

@@ -91,21 +95,20 @@ class NVIDIAEvalImpl(
 
     async def run_eval(
         self,
-
-        benchmark_config: BenchmarkConfig,
+        request: RunEvalRequest,
     ) -> Job:
         """Run an evaluation job for a benchmark."""
         model = (
-            benchmark_config.eval_candidate.model
-            if benchmark_config.eval_candidate.type == "model"
-            else benchmark_config.eval_candidate.config.model
+            request.benchmark_config.eval_candidate.model
+            if request.benchmark_config.eval_candidate.type == "model"
+            else request.benchmark_config.eval_candidate.config.model
         )
         nvidia_model = self.get_provider_model_id(model) or model
 
         result = await self._evaluator_post(
             "/v1/evaluation/jobs",
             {
-                "config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
+                "config": f"{DEFAULT_NAMESPACE}/{request.benchmark_id}",
                 "target": {"type": "model", "model": nvidia_model},
             },
         )

@@ -114,20 +117,17 @@ class NVIDIAEvalImpl(
 
     async def evaluate_rows(
         self,
-
-        input_rows: list[dict[str, Any]],
-        scoring_functions: list[str],
-        benchmark_config: BenchmarkConfig,
+        request: EvaluateRowsRequest,
     ) -> EvaluateResponse:
         raise NotImplementedError()
 
-    async def job_status(self,
+    async def job_status(self, request: JobStatusRequest) -> Job:
         """Get the status of an evaluation job.
 
         EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
         JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
         """
-        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
+        result = await self._evaluator_get(f"/v1/evaluation/jobs/{request.job_id}")
         result_status = result["status"]
 
         job_status = JobStatus.failed

@@ -140,27 +140,28 @@ class NVIDIAEvalImpl(
         elif result_status in ["cancelled"]:
             job_status = JobStatus.cancelled
 
-        return Job(job_id=job_id, status=job_status)
+        return Job(job_id=request.job_id, status=job_status)
 
-    async def job_cancel(self,
+    async def job_cancel(self, request: JobCancelRequest) -> None:
         """Cancel the evaluation job."""
-        await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
+        await self._evaluator_post(f"/v1/evaluation/jobs/{request.job_id}/cancel", {})
 
-    async def job_result(self,
+    async def job_result(self, request: JobResultRequest) -> EvaluateResponse:
         """Returns the results of the evaluation job."""
 
-
+        job_status_request = JobStatusRequest(benchmark_id=request.benchmark_id, job_id=request.job_id)
+        job = await self.job_status(job_status_request)
         status = job.status
         if not status or status != JobStatus.completed:
-            raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
+            raise ValueError(f"Job {request.job_id} not completed. Status: {status.value}")
 
-        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
+        result = await self._evaluator_get(f"/v1/evaluation/jobs/{request.job_id}/results")
 
         return EvaluateResponse(
             # TODO: these are stored in detailed results on NeMo Evaluator side; can be added
             generations=[],
             scores={
-                benchmark_id: ScoringResult(
+                request.benchmark_id: ScoringResult(
                     score_rows=[],
                     aggregated_results=result,
                 )
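In 0.5.0 the Eval provider methods take a single request model from llama_stack_api instead of separate positional arguments. A minimal sketch of the new call shape, assuming only the field names visible in the diff above (the adapter instance, benchmark ID, and benchmark config are placeholders that must already exist in your stack):

from llama_stack_api import JobResultRequest, JobStatusRequest, RunEvalRequest

# Inside an async context, against an already-constructed NVIDIA eval adapter:
job = await eval_impl.run_eval(
    RunEvalRequest(benchmark_id="my-benchmark", benchmark_config=my_benchmark_config)
)

# Status and results now use the matching request models instead of bare IDs.
status = await eval_impl.job_status(
    JobStatusRequest(benchmark_id="my-benchmark", job_id=job.job_id)
)
results = await eval_impl.job_result(
    JobResultRequest(benchmark_id="my-benchmark", job_id=job.job_id)
)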
llama_stack/providers/remote/files/s3/config.py

@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr
 
 from llama_stack.core.storage.datatypes import SqlStoreReference
 

@@ -16,8 +16,10 @@ class S3FilesImplConfig(BaseModel):
 
     bucket_name: str = Field(description="S3 bucket name to store files")
     region: str = Field(default="us-east-1", description="AWS region where the bucket is located")
-    aws_access_key_id:
-
+    aws_access_key_id: SecretStr | None = Field(
+        default=None, description="AWS access key ID (optional if using IAM roles)"
+    )
+    aws_secret_access_key: SecretStr | None = Field(
         default=None, description="AWS secret access key (optional if using IAM roles)"
     )
     endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)")
llama_stack/providers/remote/files/s3/files.py

@@ -57,8 +57,8 @@ def _create_s3_client(config: S3FilesImplConfig) -> "S3Client":
     if config.aws_access_key_id and config.aws_secret_access_key:
         s3_config.update(
             {
-                "aws_access_key_id": config.aws_access_key_id,
-                "aws_secret_access_key": config.aws_secret_access_key,
+                "aws_access_key_id": config.aws_access_key_id.get_secret_value(),
+                "aws_secret_access_key": config.aws_secret_access_key.get_secret_value(),
             }
         )
 
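The S3 credentials change from plain strings to pydantic SecretStr, so they are masked in reprs and logs and must be unwrapped with get_secret_value() before being handed to boto3, as the _create_s3_client change above shows. A small self-contained illustration of that SecretStr behavior (the Demo model is invented for the example):

from pydantic import BaseModel, SecretStr

class Demo(BaseModel):
    aws_secret_access_key: SecretStr | None = None

cfg = Demo(aws_secret_access_key="super-secret")
print(cfg)                                           # masked: aws_secret_access_key=SecretStr('**********')
print(cfg.aws_secret_access_key.get_secret_value())  # super-secret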
llama_stack/providers/remote/inference/gemini/gemini.py

@@ -12,6 +12,7 @@ from llama_stack_api import (
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
+    validate_embeddings_input_is_text,
 )
 
 from .config import GeminiConfig

@@ -37,6 +38,9 @@ class GeminiInferenceAdapter(OpenAIMixin):
         Override embeddings method to handle Gemini's missing usage statistics.
         Gemini's embedding API doesn't return usage information, so we provide default values.
         """
+        # Validate that input contains only text, not token arrays
+        validate_embeddings_input_is_text(params)
+
         # Build request params conditionally to avoid NotGiven/Omit type mismatch
         request_params: dict[str, Any] = {
             "model": await self._get_provider_model_id(params.model),
llama_stack/providers/remote/inference/openai/openai.py

@@ -24,6 +24,8 @@ class OpenAIInferenceAdapter(OpenAIMixin):
 
     provider_data_api_key_field: str = "openai_api_key"
 
+    supports_tokenized_embeddings_input: bool = True
+
     embedding_model_metadata: dict[str, dict[str, int]] = {
         "text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
         "text-embedding-3-large": {"embedding_dimension": 3072, "context_length": 8192},
llama_stack/providers/remote/inference/together/together.py

@@ -18,6 +18,7 @@ from llama_stack_api import (
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
+    validate_embeddings_input_is_text,
 )
 
 from .config import TogetherImplConfig

@@ -74,6 +75,9 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
         - does not support user param, returns 400 Unrecognized request arguments supplied: user
         - does not support dimensions param, returns 400 Unrecognized request arguments supplied: dimensions
         """
+        # Validate that input contains only text, not token arrays
+        validate_embeddings_input_is_text(params)
+
         # Together support ticket #13332 -> will not fix
         if params.user is not None:
             raise ValueError("Together's embeddings endpoint does not support user param.")
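The Gemini, Together, and watsonx adapters now call validate_embeddings_input_is_text before issuing an embeddings request, while the OpenAI adapter instead opts in to token-array input via supports_tokenized_embeddings_input = True. The helper itself lives in llama_stack_api; the following is only a hypothetical sketch of the kind of check it performs, and the real implementation may differ:

# Hypothetical sketch only: OpenAI-style embeddings endpoints accept either
# strings or pre-tokenized input (lists of token IDs); providers that cannot
# handle token arrays reject them up front.
def validate_embeddings_input_is_text(params) -> None:
    inputs = params.input if isinstance(params.input, list) else [params.input]
    for item in inputs:
        if not isinstance(item, str):
            raise ValueError("This provider only supports text input for embeddings, not token arrays.")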
llama_stack/providers/remote/inference/vertexai/config.py

@@ -19,7 +19,7 @@ class VertexAIProviderDataValidator(BaseModel):
     )
     vertex_location: str | None = Field(
         default=None,
-        description="Google Cloud location for Vertex AI (e.g.,
+        description="Google Cloud location for Vertex AI (e.g., global)",
     )
 
 

@@ -31,7 +31,7 @@ class VertexAIConfig(RemoteInferenceProviderConfig):
         description="Google Cloud project ID for Vertex AI",
     )
     location: str = Field(
-        default="
+        default="global",
         description="Google Cloud location for Vertex AI",
     )
 

@@ -39,7 +39,7 @@ class VertexAIConfig(RemoteInferenceProviderConfig):
     def sample_run_config(
         cls,
         project: str = "${env.VERTEX_AI_PROJECT:=}",
-        location: str = "${env.VERTEX_AI_LOCATION:=
+        location: str = "${env.VERTEX_AI_LOCATION:=global}",
         **kwargs,
     ) -> dict[str, Any]:
         return {
llama_stack/providers/remote/inference/vertexai/vertexai.py

@@ -40,9 +40,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
         Get the Vertex AI OpenAI-compatible API base URL.
 
         Returns the Vertex AI OpenAI-compatible endpoint URL.
-        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+        Source: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/start/openai
         """
-
+        if not self.config.location or self.config.location == "global":
+            return f"https://aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/global/endpoints/openapi"
+        else:
+            return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
 
     async def list_provider_model_ids(self) -> Iterable[str]:
         """
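The default Vertex AI location changes to "global", which maps to the un-prefixed aiplatform.googleapis.com host; any other location keeps the region-prefixed host. A standalone sketch of the same URL rule shown in the diff (the project and location values are placeholders):

def vertex_openai_base_url(project: str, location: str | None) -> str:
    # The "global" location (or no location) uses the global endpoint host ...
    if not location or location == "global":
        return f"https://aiplatform.googleapis.com/v1/projects/{project}/locations/global/endpoints/openapi"
    # ... while regional locations keep the region-prefixed host.
    return f"https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/endpoints/openapi"

print(vertex_openai_base_url("my-project", "global"))
# https://aiplatform.googleapis.com/v1/projects/my-project/locations/global/endpoints/openapi
print(vertex_openai_base_url("my-project", "us-central1"))
# https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/endpoints/openapi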
llama_stack/providers/remote/inference/vllm/config.py

@@ -4,11 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import warnings
 from pathlib import Path
 
-from pydantic import Field, HttpUrl, SecretStr,
+from pydantic import Field, HttpUrl, SecretStr, model_validator
 
-from llama_stack.providers.utils.inference.model_registry import
+from llama_stack.providers.utils.inference.model_registry import (
+    NetworkConfig,
+    RemoteInferenceProviderConfig,
+    TLSConfig,
+)
 from llama_stack_api import json_schema_type
 
 

@@ -27,23 +32,33 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         alias="api_token",
         description="The API token",
     )
-    tls_verify: bool | str = Field(
-        default=
-
+    tls_verify: bool | str | None = Field(
+        default=None,
+        deprecated=True,
+        description="DEPRECATED: Use 'network.tls.verify' instead. Whether to verify TLS certificates. "
+        "Can be a boolean or a path to a CA certificate file.",
     )
 
-    @
-
-
-        if
-
-
-
-
-
-
-
+    @model_validator(mode="after")
+    def migrate_tls_verify_to_network(self) -> "VLLMInferenceAdapterConfig":
+        """Migrate legacy tls_verify to network.tls.verify for backward compatibility."""
+        if self.tls_verify is not None:
+            warnings.warn(
+                "The 'tls_verify' config option is deprecated. Please use 'network.tls.verify' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            # Convert string path to Path if needed
+            if isinstance(self.tls_verify, str):
+                verify_value: bool | Path = Path(self.tls_verify)
+            else:
+                verify_value = self.tls_verify
+
+            if self.network is None:
+                self.network = NetworkConfig(tls=TLSConfig(verify=verify_value))
+            elif self.network.tls is None:
+                self.network.tls = TLSConfig(verify=verify_value)
+        return self
 
     @classmethod
     def sample_run_config(

@@ -55,5 +70,9 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
             "base_url": base_url,
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
-            "
+            "network": {
+                "tls": {
+                    "verify": "${env.VLLM_TLS_VERIFY:=true}",
+                },
+            },
         }
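The legacy tls_verify option remains accepted but is deprecated: the new model validator copies it into network.tls.verify and emits a DeprecationWarning. A hedged sketch of the two equivalent provider configurations, written as Python dicts in the style of sample_run_config (the base URL and CA bundle path are placeholders):

# Old style: still parses, but triggers the deprecation warning and is
# migrated to network.tls.verify by migrate_tls_verify_to_network().
legacy_config = {
    "base_url": "https://my-vllm.example.com/v1",
    "api_token": "${env.VLLM_API_TOKEN:=fake}",
    "tls_verify": "/etc/ssl/certs/my-ca.pem",  # a bool or a CA bundle path
}

# New style, matching what sample_run_config now emits.
new_config = {
    "base_url": "https://my-vllm.example.com/v1",
    "api_token": "${env.VLLM_API_TOKEN:=fake}",
    "network": {
        "tls": {
            "verify": "${env.VLLM_TLS_VERIFY:=true}",
        },
    },
}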
llama_stack/providers/remote/inference/vllm/vllm.py

@@ -73,9 +73,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
 
-    def get_extra_client_params(self):
-        return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
-
 
     async def check_model_availability(self, model: str) -> bool:
         """
         Skip the check when running without authentication.
llama_stack/providers/remote/inference/watsonx/watsonx.py

@@ -23,6 +23,7 @@ from llama_stack_api import (
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
+    validate_embeddings_input_is_text,
 )
 
 logger = get_logger(name=__name__, category="providers::remote::watsonx")

@@ -147,6 +148,9 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         """
         Override parent method to add watsonx-specific parameters.
         """
+        # Validate that input contains only text, not token arrays
+        validate_embeddings_input_is_text(params)
+
         model_obj = await self.model_store.get_model(params.model)
 
         # Convert input to list if it's a string
llama_stack/providers/remote/post_training/nvidia/models.py

@@ -5,23 +5,15 @@
 # the root directory of this source tree.
 
 
-from llama_stack.
-from llama_stack.providers.utils.inference.model_registry import (
-    ProviderModelEntry,
-    build_hf_repo_model_entry,
-)
+from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
 
 _MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta/llama-3.1-8b-instruct",
-
+        "Llama3.1-8B-Instruct",
     ),
     build_hf_repo_model_entry(
         "meta/llama-3.2-1b-instruct",
-
+        "Llama3.2-1B-Instruct",
     ),
 ]
-
-
-def get_model_entries() -> list[ProviderModelEntry]:
-    return _MODEL_ENTRIES
llama_stack/providers/remote/post_training/nvidia/post_training.py

@@ -14,13 +14,15 @@ from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostT
 from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack_api import (
-
-
+    CancelTrainingJobRequest,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )
 
 from .models import _MODEL_ENTRIES

@@ -156,7 +158,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
         return ListNvidiaPostTrainingJobs(data=jobs)
 
-    async def get_training_job_status(
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> NvidiaPostTrainingJobStatusResponse:
         """Get the status of a customization job.
         Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
 

@@ -178,8 +182,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         """
         response = await self._make_request(
             "GET",
-            f"/v1/customization/jobs/{job_uuid}/status",
-            params={"job_id": job_uuid},
+            f"/v1/customization/jobs/{request.job_uuid}/status",
+            params={"job_id": request.job_uuid},
         )
 
         api_status = response.pop("status").lower()

@@ -187,18 +191,20 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
         return NvidiaPostTrainingJobStatusResponse(
             status=JobStatus(mapped_status),
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             started_at=datetime.fromisoformat(response.pop("created_at")),
             updated_at=datetime.fromisoformat(response.pop("updated_at")),
             **response,
         )
 
-    async def cancel_training_job(self,
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
         await self._make_request(
-            method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
+            method="POST", path=f"/v1/customization/jobs/{request.job_uuid}/cancel", params={"job_id": request.job_uuid}
         )
 
-    async def get_training_job_artifacts(
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse:
         raise NotImplementedError("Job artifacts are not implemented yet")
 
     async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:

@@ -206,13 +212,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
     async def supervised_fine_tune(
         self,
-
-        training_config: dict[str, Any],
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> NvidiaPostTrainingJob:
         """
         Fine-tunes a model on a dataset.

@@ -300,13 +300,16 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
         User is informed about unsupported parameters via warnings.
         """
 
+        # Convert training_config to dict for internal processing
+        training_config = request.training_config.model_dump()
+
         # Check for unsupported method parameters
         unsupported_method_params = []
-        if checkpoint_dir:
-            unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
-        if hyperparam_search_config:
+        if request.checkpoint_dir:
+            unsupported_method_params.append(f"checkpoint_dir={request.checkpoint_dir}")
+        if request.hyperparam_search_config:
             unsupported_method_params.append("hyperparam_search_config")
-        if logger_config:
+        if request.logger_config:
             unsupported_method_params.append("logger_config")
 
         if unsupported_method_params:

@@ -344,7 +347,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
         # Prepare base job configuration
         job_config = {
-            "config": model,
+            "config": request.model,
             "dataset": {
                 "name": training_config["data_config"]["dataset_id"],
                 "namespace": self.config.dataset_namespace,

@@ -388,14 +391,14 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
             job_config["hyperparameters"].pop("sft")
 
         # Handle LoRA-specific configuration
-        if algorithm_config:
-            if algorithm_config.type == "LoRA":
-                warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
+        if request.algorithm_config:
+            if request.algorithm_config.type == "LoRA":
+                warn_unsupported_params(request.algorithm_config, supported_params["lora_config"], "LoRA config")
                 job_config["hyperparameters"]["lora"] = {
-                    k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
+                    k: v for k, v in {"alpha": request.algorithm_config.alpha}.items() if v is not None
                 }
             else:
-                raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
+                raise NotImplementedError(f"Unsupported algorithm config: {request.algorithm_config}")
 
         # Create the customization job
         response = await self._make_request(

@@ -416,12 +419,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
 
     async def preference_optimize(
         self,
-
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         """Optimize a model based on preference data."""
         raise NotImplementedError("Preference optimization is not implemented yet")
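The NVIDIA post-training adapter follows the same pattern: each entry point now takes a single request model. A rough sketch of the new call shape, assuming only the field names visible in this diff (model, training_config, hyperparam_search_config, logger_config, checkpoint_dir, algorithm_config, job_uuid); the request models may carry additional required fields not shown here, and all values below are placeholders:

from llama_stack_api import GetTrainingJobStatusRequest, SupervisedFineTuneRequest

job = await post_training.supervised_fine_tune(
    SupervisedFineTuneRequest(
        model="meta/llama-3.1-8b-instruct",
        training_config=my_training_config,  # dumped internally via model_dump()
        hyperparam_search_config={},
        logger_config={},
        checkpoint_dir=None,
        algorithm_config=None,  # or a LoRA config whose .type == "LoRA"
    )
)

# Assuming the returned job exposes its UUID under job_uuid:
status = await post_training.get_training_job_status(
    GetTrainingJobStatusRequest(job_uuid=job.job_uuid)
)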
llama_stack/providers/remote/safety/bedrock/bedrock.py

@@ -5,12 +5,13 @@
 # the root directory of this source tree.
 
 import json
-from typing import Any
 
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-
+    GetShieldRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,

@@ -24,7 +25,7 @@ from .config import BedrockSafetyConfig
 logger = get_logger(name=__name__, category="safety::bedrock")
 
 
-class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
+class BedrockSafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: BedrockSafetyConfig) -> None:
         self.config = config
         self.registered_shields = []

@@ -55,49 +56,31 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass
 
-    async def run_shield(
-
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
-
-        """
-        This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
-        ```content = [
-            {
-                "text": {
-                    "text": "Is the AB503 Product a better investment than the S&P 500?"
-                }
-            }
-        ]```
-        Incoming messages contain content, role . For now we will extract the content and
-        default the "qualifiers": ["query"]
-        """
+            raise ValueError(f"Shield {request.shield_id} not found")
 
         shield_params = shield.params
-        logger.debug(f"run_shield::{shield_params}::messages={messages}")
+        logger.debug(f"run_shield::{shield_params}::messages={request.messages}")
 
-        # - convert the messages into format Bedrock expects
         content_messages = []
-        for message in messages:
+        for message in request.messages:
             content_messages.append({"text": {"text": message.content}})
         logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
 
         response = self.bedrock_runtime_client.apply_guardrail(
             guardrailIdentifier=shield.provider_resource_id,
             guardrailVersion=shield_params["guardrailVersion"],
-            source="OUTPUT",
+            source="OUTPUT",
             content=content_messages,
         )
         if response["action"] == "GUARDRAIL_INTERVENED":
             user_message = ""
             metadata = {}
             for output in response["outputs"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 user_message = output["text"]
             for assessment in response["assessments"]:
-                # guardrails returns a list - however for this implementation we will leverage the last values
                 metadata = dict(assessment)
 
         return RunShieldResponse(
llama_stack/providers/remote/safety/nvidia/nvidia.py

@@ -9,9 +9,11 @@ from typing import Any
 import requests
 
 from llama_stack.log import get_logger
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,

@@ -25,7 +27,7 @@ from .config import NVIDIASafetyConfig
 logger = get_logger(name=__name__, category="safety::nvidia")
 
 
-class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
+class NVIDIASafetyAdapter(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     def __init__(self, config: NVIDIASafetyConfig) -> None:
         """
         Initialize the NVIDIASafetyAdapter with a given safety configuration.

@@ -48,32 +50,14 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass
 
-    async def run_shield(
-
-
-        """
-        Run a safety shield check against the provided messages.
-
-        Args:
-            shield_id (str): The unique identifier for the shield to be used.
-            messages (List[Message]): A list of Message objects representing the conversation history.
-            params (Optional[dict[str, Any]]): Additional parameters for the shield check.
-
-        Returns:
-            RunShieldResponse: The response containing safety violation details if any.
-
-        Raises:
-            ValueError: If the shield with the provided shield_id is not found.
-        """
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        """Run a safety shield check against the provided messages."""
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")
 
         self.shield = NeMoGuardrails(self.config, shield.shield_id)
-        return await self.shield.run(messages)
-
-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
+        return await self.shield.run(request.messages)
 
 
 class NeMoGuardrails:
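Both remote safety adapters now mix in ShieldToModerationMixin (the new helper module llama_stack/providers/utils/safety.py in the file list above), which lets run_moderation be served on top of run_shield, and run_shield itself takes a single RunShieldRequest. A minimal sketch of the new call shape, using only the field names visible in the diff; safety_impl, the shield ID, and messages are placeholders, and the violation/user_message attributes are assumed from the SafetyViolation type these adapters import:

from llama_stack_api import RunShieldRequest

# messages is assumed to already be a list of OpenAI-style chat messages in
# the stack's message format; the shield must be registered beforehand.
response = await safety_impl.run_shield(
    RunShieldRequest(shield_id="my-guardrail", messages=messages)
)
if response.violation is not None:
    print(response.violation.user_message)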
|