llama-stack-api 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- llama_stack_api/__init__.py +175 -20
- llama_stack_api/agents/__init__.py +38 -0
- llama_stack_api/agents/api.py +52 -0
- llama_stack_api/agents/fastapi_routes.py +268 -0
- llama_stack_api/agents/models.py +181 -0
- llama_stack_api/common/errors.py +15 -0
- llama_stack_api/connectors/__init__.py +38 -0
- llama_stack_api/connectors/api.py +50 -0
- llama_stack_api/connectors/fastapi_routes.py +103 -0
- llama_stack_api/connectors/models.py +103 -0
- llama_stack_api/conversations/__init__.py +61 -0
- llama_stack_api/conversations/api.py +44 -0
- llama_stack_api/conversations/fastapi_routes.py +177 -0
- llama_stack_api/conversations/models.py +245 -0
- llama_stack_api/datasetio/__init__.py +34 -0
- llama_stack_api/datasetio/api.py +42 -0
- llama_stack_api/datasetio/fastapi_routes.py +94 -0
- llama_stack_api/datasetio/models.py +48 -0
- llama_stack_api/eval/__init__.py +55 -0
- llama_stack_api/eval/api.py +51 -0
- llama_stack_api/eval/compat.py +300 -0
- llama_stack_api/eval/fastapi_routes.py +126 -0
- llama_stack_api/eval/models.py +141 -0
- llama_stack_api/inference/__init__.py +207 -0
- llama_stack_api/inference/api.py +93 -0
- llama_stack_api/inference/fastapi_routes.py +243 -0
- llama_stack_api/inference/models.py +1035 -0
- llama_stack_api/models/__init__.py +47 -0
- llama_stack_api/models/api.py +38 -0
- llama_stack_api/models/fastapi_routes.py +104 -0
- llama_stack_api/{models.py → models/models.py} +65 -79
- llama_stack_api/openai_responses.py +32 -6
- llama_stack_api/post_training/__init__.py +73 -0
- llama_stack_api/post_training/api.py +36 -0
- llama_stack_api/post_training/fastapi_routes.py +116 -0
- llama_stack_api/{post_training.py → post_training/models.py} +55 -86
- llama_stack_api/prompts/__init__.py +47 -0
- llama_stack_api/prompts/api.py +44 -0
- llama_stack_api/prompts/fastapi_routes.py +163 -0
- llama_stack_api/prompts/models.py +177 -0
- llama_stack_api/resource.py +0 -1
- llama_stack_api/safety/__init__.py +37 -0
- llama_stack_api/safety/api.py +29 -0
- llama_stack_api/safety/datatypes.py +83 -0
- llama_stack_api/safety/fastapi_routes.py +55 -0
- llama_stack_api/safety/models.py +38 -0
- llama_stack_api/schema_utils.py +47 -4
- llama_stack_api/scoring/__init__.py +66 -0
- llama_stack_api/scoring/api.py +35 -0
- llama_stack_api/scoring/fastapi_routes.py +67 -0
- llama_stack_api/scoring/models.py +81 -0
- llama_stack_api/scoring_functions/__init__.py +50 -0
- llama_stack_api/scoring_functions/api.py +39 -0
- llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
- llama_stack_api/{scoring_functions.py → scoring_functions/models.py} +67 -64
- llama_stack_api/shields/__init__.py +41 -0
- llama_stack_api/shields/api.py +39 -0
- llama_stack_api/shields/fastapi_routes.py +104 -0
- llama_stack_api/shields/models.py +74 -0
- llama_stack_api/validators.py +46 -0
- llama_stack_api/vector_io/__init__.py +88 -0
- llama_stack_api/vector_io/api.py +234 -0
- llama_stack_api/vector_io/fastapi_routes.py +447 -0
- llama_stack_api/{vector_io.py → vector_io/models.py} +99 -377
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
- llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
- llama_stack_api/agents.py +0 -173
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/eval.py +0 -137
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/safety.py +0 -132
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/shields.py +0 -93
- llama_stack_api-0.4.4.dist-info/RECORD +0 -70
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/top_level.txt +0 -0
llama_stack_api/prompts/models.py
ADDED
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Pydantic models for Prompts API requests and responses.
+
+This module defines the request and response models for the Prompts API
+using Pydantic with Field descriptions for OpenAPI schema generation.
+"""
+
+import re
+import secrets
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+
+from llama_stack_api.schema_utils import json_schema_type
+
+
+@json_schema_type
+class Prompt(BaseModel):
+    """A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack."""
+
+    prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
+    version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
+    prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
+    variables: list[str] = Field(
+        default_factory=list, description="List of variable names that can be used in the prompt template"
+    )
+    is_default: bool = Field(
+        default=False, description="Boolean indicating whether this version is the default version"
+    )
+
+    @field_validator("prompt_id")
+    @classmethod
+    def validate_prompt_id(cls, prompt_id: str) -> str:
+        if not isinstance(prompt_id, str):
+            raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
+
+        if not prompt_id.startswith("pmpt_"):
+            raise ValueError("prompt_id must start with 'pmpt_' prefix")
+
+        hex_part = prompt_id[5:]
+        if len(hex_part) != 48:
+            raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
+
+        for char in hex_part:
+            if char not in "0123456789abcdef":
+                raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
+
+        return prompt_id
+
+    @field_validator("version")
+    @classmethod
+    def validate_version(cls, prompt_version: int) -> int:
+        if prompt_version < 1:
+            raise ValueError("version must be >= 1")
+        return prompt_version
+
+    @model_validator(mode="after")
+    def validate_prompt_variables(self):
+        """Validate that all variables used in the prompt are declared in the variables list."""
+        if not self.prompt:
+            return self
+
+        prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
+        declared_variables = set(self.variables)
+
+        undeclared = prompt_variables - declared_variables
+        if undeclared:
+            raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
+
+        return self
+
+    @classmethod
+    def generate_prompt_id(cls) -> str:
+        # Generate 48 hex characters (24 bytes)
+        random_bytes = secrets.token_bytes(24)
+        hex_string = random_bytes.hex()
+        return f"pmpt_{hex_string}"
+
+
+@json_schema_type
+class ListPromptsResponse(BaseModel):
+    """Response model to list prompts."""
+
+    data: list[Prompt]
+
+
+# Request models for each endpoint
+
+
+@json_schema_type
+class ListPromptVersionsRequest(BaseModel):
+    """Request model for listing all versions of a prompt."""
+
+    prompt_id: str = Field(..., description="The identifier of the prompt to list versions for.")
+
+
+@json_schema_type
+class GetPromptRequest(BaseModel):
+    """Request model for getting a prompt by ID and optional version."""
+
+    prompt_id: str = Field(..., description="The identifier of the prompt to get.")
+    version: int | None = Field(default=None, description="The version of the prompt to get (defaults to latest).")
+
+
+@json_schema_type
+class CreatePromptRequest(BaseModel):
+    """Request model for creating a new prompt."""
+
+    prompt: str = Field(..., description="The prompt text content with variable placeholders.")
+    variables: list[str] | None = Field(
+        default=None, description="List of variable names that can be used in the prompt template."
+    )
+
+
+@json_schema_type
+class UpdatePromptBodyRequest(BaseModel):
+    """Request body model for updating a prompt."""
+
+    prompt: str = Field(..., description="The updated prompt text content.")
+    version: int = Field(..., description="The current version of the prompt being updated.")
+    variables: list[str] | None = Field(
+        default=None, description="Updated list of variable names that can be used in the prompt template."
+    )
+    set_as_default: bool = Field(default=True, description="Set the new version as the default (default=True).")
+
+
+@json_schema_type
+class UpdatePromptRequest(BaseModel):
+    """Request model for updating a prompt (combines path and body parameters)."""
+
+    prompt_id: str = Field(..., description="The identifier of the prompt to update.")
+    prompt: str = Field(..., description="The updated prompt text content.")
+    version: int = Field(..., description="The current version of the prompt being updated.")
+    variables: list[str] | None = Field(
+        default=None, description="Updated list of variable names that can be used in the prompt template."
+    )
+    set_as_default: bool = Field(default=True, description="Set the new version as the default (default=True).")
+
+
+@json_schema_type
+class DeletePromptRequest(BaseModel):
+    """Request model for deleting a prompt."""
+
+    prompt_id: str = Field(..., description="The identifier of the prompt to delete.")
+
+
+@json_schema_type
+class SetDefaultVersionBodyRequest(BaseModel):
+    """Request body model for setting the default version of a prompt."""
+
+    version: int = Field(..., description="The version to set as default.")
+
+
+@json_schema_type
+class SetDefaultVersionRequest(BaseModel):
+    """Request model for setting the default version of a prompt (combines path and body parameters)."""
+
+    prompt_id: str = Field(..., description="The identifier of the prompt.")
+    version: int = Field(..., description="The version to set as default.")
+
+
+__all__ = [
+    "CreatePromptRequest",
+    "DeletePromptRequest",
+    "GetPromptRequest",
+    "ListPromptVersionsRequest",
+    "ListPromptsResponse",
+    "Prompt",
+    "SetDefaultVersionBodyRequest",
+    "SetDefaultVersionRequest",
+    "UpdatePromptBodyRequest",
+    "UpdatePromptRequest",
+]
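For reference, a minimal sketch of how the new Prompt validators behave, assuming llama_stack_api 0.5.0rc1 is installed (this example is not part of the diff):

# Sketch only: exercises the validators defined in prompts/models.py above.
from pydantic import ValidationError

from llama_stack_api.prompts.models import Prompt

pid = Prompt.generate_prompt_id()  # "pmpt_" + 48 lowercase hex chars

# Valid: every {{ variable }} used in the template is declared.
Prompt(prompt_id=pid, version=1, prompt="Hello {{ name }}!", variables=["name"])

# Rejected by the model validator: {{ name }} is used but not declared.
try:
    Prompt(prompt_id=pid, version=1, prompt="Hello {{ name }}!")
except ValidationError as exc:
    print(exc)  # mentions "Prompt contains undeclared variables: ['name']"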
llama_stack_api/resource.py
CHANGED
llama_stack_api/safety/__init__.py
ADDED
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Safety API protocol and models.
+
+This module contains the Safety protocol definition for content moderation and safety shields.
+Pydantic models are defined in llama_stack_api.safety.models.
+The FastAPI router is defined in llama_stack_api.safety.fastapi_routes.
+"""
+
+from . import fastapi_routes
+from .api import Safety
+from .datatypes import (
+    ModerationObject,
+    ModerationObjectResults,
+    RunShieldResponse,
+    SafetyViolation,
+    ShieldStore,
+    ViolationLevel,
+)
+from .models import RunModerationRequest, RunShieldRequest
+
+__all__ = [
+    "Safety",
+    "ShieldStore",
+    "ModerationObject",
+    "ModerationObjectResults",
+    "ViolationLevel",
+    "SafetyViolation",
+    "RunShieldResponse",
+    "RunShieldRequest",
+    "RunModerationRequest",
+    "fastapi_routes",
+]
llama_stack_api/safety/api.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Protocol, runtime_checkable
+
+from llama_stack_api.safety.datatypes import ModerationObject, RunShieldResponse, ShieldStore
+
+from .models import RunModerationRequest, RunShieldRequest
+
+
+@runtime_checkable
+class Safety(Protocol):
+    """Safety API for content moderation and safety shields.
+
+    OpenAI-compatible Moderations API with additional shield capabilities.
+    """
+
+    shield_store: ShieldStore
+
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        """Run a safety shield on messages."""
+        ...
+
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        """Classify if inputs are potentially harmful."""
+        ...
llama_stack_api/safety/datatypes.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from enum import Enum
+from typing import Any, Protocol
+
+from pydantic import BaseModel, Field
+
+from llama_stack_api.schema_utils import json_schema_type
+from llama_stack_api.shields import GetShieldRequest, Shield
+
+
+@json_schema_type
+class ModerationObjectResults(BaseModel):
+    """A moderation result object containing flagged status and category information."""
+
+    flagged: bool = Field(..., description="Whether any of the below categories are flagged")
+    categories: dict[str, bool] | None = Field(
+        None, description="A dictionary of the categories, and whether they are flagged or not"
+    )
+    category_applied_input_types: dict[str, list[str]] | None = Field(
+        None, description="A dictionary of the categories along with the input type(s) that the score applies to"
+    )
+    category_scores: dict[str, float] | None = Field(
+        None, description="A dictionary of the categories along with their scores as predicted by model"
+    )
+    user_message: str | None = Field(None, description="A message to convey to the user about the moderation result")
+    metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata about the moderation")
+
+
+@json_schema_type
+class ModerationObject(BaseModel):
+    """A moderation object containing the results of content classification."""
+
+    id: str = Field(..., description="The unique identifier for the moderation request")
+    model: str = Field(..., description="The model used to generate the moderation results")
+    results: list[ModerationObjectResults] = Field(..., description="A list of moderation result objects")
+
+
+@json_schema_type
+class ViolationLevel(Enum):
+    """Severity level of a safety violation."""
+
+    INFO = "info"  # Informational level violation that does not require action
+    WARN = "warn"  # Warning level violation that suggests caution but allows continuation
+    ERROR = "error"  # Error level violation that requires blocking or intervention
+
+
+@json_schema_type
+class SafetyViolation(BaseModel):
+    """Details of a safety violation detected by content moderation."""
+
+    violation_level: ViolationLevel = Field(..., description="Severity level of the violation")
+    user_message: str | None = Field(None, description="Message to convey to the user about the violation")
+    metadata: dict[str, Any] = Field(
+        default_factory=dict, description="Additional metadata including specific violation codes"
+    )
+
+
+@json_schema_type
+class RunShieldResponse(BaseModel):
+    """Response from running a safety shield."""
+
+    violation: SafetyViolation | None = Field(None, description="Safety violation detected by the shield, if any")
+
+
+class ShieldStore(Protocol):
+    """Protocol for accessing shields."""
+
+    async def get_shield(self, request: GetShieldRequest) -> Shield: ...
+
+
+__all__ = [
+    "ModerationObjectResults",
+    "ModerationObject",
+    "ViolationLevel",
+    "SafetyViolation",
+    "RunShieldResponse",
+    "ShieldStore",
+]
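A hedged sketch tying the Safety protocol and these datatypes together. SafetyStub is a hypothetical stand-in implementation, not part of the package, and the metadata key is illustrative only:

from llama_stack_api.safety.api import Safety
from llama_stack_api.safety.datatypes import (
    ModerationObject,
    ModerationObjectResults,
    RunShieldResponse,
    SafetyViolation,
    ViolationLevel,
)
from llama_stack_api.safety.models import RunModerationRequest, RunShieldRequest


class SafetyStub:
    """Hypothetical Safety implementation that flags a fixed keyword."""

    shield_store = None  # a real implementation supplies a ShieldStore

    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
        flagged = any("forbidden" in str(m) for m in request.messages)
        if not flagged:
            return RunShieldResponse()  # violation defaults to None
        return RunShieldResponse(
            violation=SafetyViolation(
                violation_level=ViolationLevel.ERROR,
                user_message="This content violates the configured policy.",
                metadata={"violation_code": "illustrative-code"},  # hypothetical key
            )
        )

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        inputs = [request.input] if isinstance(request.input, str) else request.input
        return ModerationObject(
            id="modr-0",
            model=request.model or "default-shield",
            results=[ModerationObjectResults(flagged=False) for _ in inputs],
        )


# Safety is @runtime_checkable: isinstance verifies member *presence* only,
# not signatures, so this is a cheap startup sanity check.
assert isinstance(SafetyStub(), Safety)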
llama_stack_api/safety/fastapi_routes.py
ADDED
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Annotated
+
+from fastapi import APIRouter, Body
+
+from llama_stack_api.router_utils import standard_responses
+from llama_stack_api.version import LLAMA_STACK_API_V1
+
+from .api import Safety
+from .datatypes import ModerationObject, RunShieldResponse
+from .models import RunModerationRequest, RunShieldRequest
+
+
+def create_router(impl: Safety) -> APIRouter:
+    """Create a FastAPI router for the Safety API."""
+    router = APIRouter(
+        prefix=f"/{LLAMA_STACK_API_V1}",
+        tags=["Safety"],
+        responses=standard_responses,
+    )
+
+    @router.post(
+        "/safety/run-shield",
+        response_model=RunShieldResponse,
+        summary="Run Shield",
+        description="Run a safety shield on messages to check for policy violations.",
+        responses={
+            200: {"description": "The shield response indicating any violations detected."},
+        },
+    )
+    async def run_shield(
+        request: Annotated[RunShieldRequest, Body(...)],
+    ) -> RunShieldResponse:
+        return await impl.run_shield(request)
+
+    @router.post(
+        "/moderations",
+        response_model=ModerationObject,
+        summary="Create Moderation",
+        description="Classifies if text inputs are potentially harmful. OpenAI-compatible endpoint.",
+        responses={
+            200: {"description": "The moderation results for the input."},
+        },
+    )
+    async def run_moderation(
+        request: Annotated[RunModerationRequest, Body(...)],
+    ) -> ModerationObject:
+        return await impl.run_moderation(request)
+
+    return router
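Mounting the router on an application is then a single call. A sketch reusing the hypothetical SafetyStub from above, and assuming LLAMA_STACK_API_V1 resolves to "v1":

from fastapi import FastAPI

from llama_stack_api.safety import fastapi_routes

app = FastAPI()
app.include_router(fastapi_routes.create_router(SafetyStub()))
# Registers POST /v1/safety/run-shield and POST /v1/moderations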
llama_stack_api/safety/models.py
ADDED
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel, Field
+
+from llama_stack_api.inference import OpenAIMessageParam
+from llama_stack_api.schema_utils import json_schema_type
+
+
+@json_schema_type
+class RunShieldRequest(BaseModel):
+    """Request model for running a safety shield."""
+
+    shield_id: str = Field(..., description="The identifier of the shield to run", min_length=1)
+    messages: list[OpenAIMessageParam] = Field(..., description="The messages to run the shield on")
+
+
+@json_schema_type
+class RunModerationRequest(BaseModel):
+    """Request model for running content moderation."""
+
+    input: str | list[str] = Field(
+        ...,
+        description="Input (or inputs) to classify. Can be a single string or an array of strings.",
+    )
+    model: str | None = Field(
+        None,
+        description="The content moderation model to use. If not specified, the default shield will be used.",
+    )
+
+
+__all__ = [
+    "RunShieldRequest",
+    "RunModerationRequest",
+]
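On the wire these request models are plain JSON bodies. A quick check with FastAPI's TestClient against the app sketched above (still assuming the "v1" prefix):

from fastapi.testclient import TestClient

client = TestClient(app)

resp = client.post("/v1/moderations", json={"input": ["hello", "world"]})
print(resp.status_code, resp.json())  # 200, two unflagged results from the stub

# shield_id carries min_length=1, so an empty id fails validation with a 422.
bad = client.post("/v1/safety/run-shield", json={"shield_id": "", "messages": []})
print(bad.status_code)  # 422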
llama_stack_api/schema_utils.py
CHANGED
@@ -149,7 +149,6 @@ class WebMethod:
     raw_bytes_request_body: bool | None = False
     # A descriptive name of the corresponding span created by tracing
     descriptive_name: str | None = None
-    required_scope: str | None = None
     deprecated: bool | None = False
     require_authentication: bool | None = True
 
@@ -166,7 +165,6 @@ def webmethod(
     response_examples: list[Any] | None = None,
     raw_bytes_request_body: bool | None = False,
     descriptive_name: str | None = None,
-    required_scope: str | None = None,
     deprecated: bool | None = False,
     require_authentication: bool | None = True,
 ) -> Callable[[CallableT], CallableT]:
@@ -177,7 +175,6 @@ def webmethod(
     :param public: True if the operation can be invoked without prior authentication.
     :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
     :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
-    :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
     :param require_authentication: Whether this endpoint requires authentication (default True).
     """
 
@@ -191,7 +188,6 @@ def webmethod(
         response_examples=response_examples,
         raw_bytes_request_body=raw_bytes_request_body,
         descriptive_name=descriptive_name,
-        required_scope=required_scope,
         deprecated=deprecated,
         require_authentication=require_authentication if require_authentication is not None else True,
     )
@@ -206,3 +202,50 @@ def webmethod(
         return func
 
     return wrap
+
+
+def remove_null_from_anyof(schema: dict, *, add_nullable: bool = False) -> None:
+    """Remove null type from anyOf and optionally add nullable flag.
+
+    Converts Pydantic's default OpenAPI 3.1 style:
+        anyOf: [{type: X, enum: [...]}, {type: null}]
+
+    To flattened format:
+        type: X
+        enum: [...]
+        nullable: true  # only if add_nullable=True
+
+    Args:
+        schema: The JSON schema dict to modify in-place
+        add_nullable: If True, adds 'nullable: true' when null was present.
+            Use True for OpenAPI 3.0 compatibility with OpenAI's spec.
+    """
+    # Handle anyOf format: anyOf: [{type: string, enum: [...]}, {type: null}]
+    if "anyOf" in schema:
+        non_null = [s for s in schema["anyOf"] if s.get("type") != "null"]
+        has_null = len(non_null) < len(schema["anyOf"])
+
+        if len(non_null) == 1:
+            # Flatten to single type
+            only_schema = non_null[0]
+            schema.pop("anyOf")
+            schema.update(only_schema)
+            if has_null and add_nullable:
+                schema["nullable"] = True
+
+    # Handle OpenAPI 3.1 format: type: ['string', 'null']
+    elif isinstance(schema.get("type"), list) and "null" in schema["type"]:
+        has_null = "null" in schema["type"]
+        schema["type"].remove("null")
+        if len(schema["type"]) == 1:
+            schema["type"] = schema["type"][0]
+        if has_null and add_nullable:
+            schema["nullable"] = True
+
+
+def nullable_openai_style(schema: dict) -> None:
+    """Shorthand for remove_null_from_anyof with add_nullable=True.
+
+    Use this for fields that need OpenAPI 3.0 nullable style to match OpenAI's spec.
+    """
+    remove_null_from_anyof(schema, add_nullable=True)
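The new helpers operate on plain schema dicts, so their effect is easy to see in isolation (a sketch, assuming 0.5.0rc1):

from llama_stack_api.schema_utils import nullable_openai_style, remove_null_from_anyof

# anyOf form: flattened, with the OpenAPI 3.0 nullable flag added.
schema = {"anyOf": [{"type": "string", "enum": ["a", "b"]}, {"type": "null"}]}
nullable_openai_style(schema)
print(schema)  # {'type': 'string', 'enum': ['a', 'b'], 'nullable': True}

# OpenAPI 3.1 type-list form: null stripped, no nullable flag by default.
schema31 = {"type": ["string", "null"]}
remove_null_from_anyof(schema31)
print(schema31)  # {'type': 'string'}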
llama_stack_api/scoring/__init__.py
ADDED
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Scoring API protocol and models.
+
+This module contains the Scoring protocol definition.
+Pydantic models are defined in llama_stack_api.scoring.models.
+The FastAPI router is defined in llama_stack_api.scoring.fastapi_routes.
+"""
+
+# Import fastapi_routes for router factory access
+# Import scoring_functions for re-export
+from llama_stack_api.scoring_functions import (
+    AggregationFunctionType,
+    BasicScoringFnParams,
+    CommonScoringFnFields,
+    ListScoringFunctionsResponse,
+    LLMAsJudgeScoringFnParams,
+    RegexParserScoringFnParams,
+    ScoringFn,
+    ScoringFnInput,
+    ScoringFnParams,
+    ScoringFnParamsType,
+    ScoringFunctions,
+)
+
+from . import fastapi_routes
+
+# Import protocol for FastAPI router
+from .api import Scoring, ScoringFunctionStore
+
+# Import models for re-export
+from .models import (
+    ScoreBatchRequest,
+    ScoreBatchResponse,
+    ScoreRequest,
+    ScoreResponse,
+    ScoringResult,
+    ScoringResultRow,
+)
+
+__all__ = [
+    "Scoring",
+    "ScoringFunctionStore",
+    "ScoringResult",
+    "ScoringResultRow",
+    "ScoreBatchResponse",
+    "ScoreResponse",
+    "ScoreRequest",
+    "ScoreBatchRequest",
+    "AggregationFunctionType",
+    "BasicScoringFnParams",
+    "CommonScoringFnFields",
+    "LLMAsJudgeScoringFnParams",
+    "ListScoringFunctionsResponse",
+    "RegexParserScoringFnParams",
+    "ScoringFn",
+    "ScoringFnInput",
+    "ScoringFnParams",
+    "ScoringFnParamsType",
+    "ScoringFunctions",
+    "fastapi_routes",
+]
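One consequence of these re-exports: the scoring_functions symbols resolve to the same objects under either import path (a sketch, assuming 0.5.0rc1):

from llama_stack_api.scoring import LLMAsJudgeScoringFnParams as A
from llama_stack_api.scoring_functions import LLMAsJudgeScoringFnParams as B

assert A is B  # scoring re-exports, it does not redefine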
llama_stack_api/scoring/api.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Scoring API protocol definition.
+
+This module contains the Scoring protocol definition.
+Pydantic models are defined in llama_stack_api.scoring.models.
+The FastAPI router is defined in llama_stack_api.scoring.fastapi_routes.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from llama_stack_api.scoring_functions import ScoringFn
+
+from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
+
+
+class ScoringFunctionStore(Protocol):
+    """Protocol for storing and retrieving scoring functions."""
+
+    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
+
+
+@runtime_checkable
+class Scoring(Protocol):
+    """Protocol for scoring operations."""
+
+    scoring_function_store: ScoringFunctionStore
+
+    async def score_batch(self, request: ScoreBatchRequest) -> ScoreBatchResponse: ...
+
+    async def score(self, request: ScoreRequest) -> ScoreResponse: ...
llama_stack_api/scoring/fastapi_routes.py
ADDED
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""FastAPI router for the Scoring API.
+
+This module defines the FastAPI router for the Scoring API using standard
+FastAPI route decorators.
+"""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Body
+
+from llama_stack_api.router_utils import standard_responses
+from llama_stack_api.version import LLAMA_STACK_API_V1
+
+from .api import Scoring
+from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
+
+
+def create_router(impl: Scoring) -> APIRouter:
+    """Create a FastAPI router for the Scoring API.
+
+    Args:
+        impl: The Scoring implementation instance
+
+    Returns:
+        APIRouter configured for the Scoring API
+    """
+    router = APIRouter(
+        prefix=f"/{LLAMA_STACK_API_V1}",
+        tags=["Scoring"],
+        responses=standard_responses,
+    )
+
+    @router.post(
+        "/scoring/score",
+        response_model=ScoreResponse,
+        summary="Score a list of rows.",
+        description="Score a list of rows.",
+        responses={
+            200: {"description": "A ScoreResponse object containing rows and aggregated results."},
+        },
+    )
+    async def score(
+        request: Annotated[ScoreRequest, Body(...)],
+    ) -> ScoreResponse:
+        return await impl.score(request)
+
+    @router.post(
+        "/scoring/score-batch",
+        response_model=ScoreBatchResponse,
+        summary="Score a batch of rows.",
+        description="Score a batch of rows.",
+        responses={
+            200: {"description": "A ScoreBatchResponse."},
+        },
+    )
+    async def score_batch(
+        request: Annotated[ScoreBatchRequest, Body(...)],
+    ) -> ScoreBatchResponse:
+        return await impl.score_batch(request)
+
+    return router
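Wiring mirrors the Safety router. A sketch where scoring_impl stands for any object satisfying the Scoring protocol (the helper name is hypothetical, not part of the package):

from fastapi import FastAPI

from llama_stack_api.scoring import fastapi_routes
from llama_stack_api.scoring.api import Scoring


def mount_scoring(app: FastAPI, scoring_impl: Scoring) -> None:
    # Registers POST /v1/scoring/score and POST /v1/scoring/score-batch,
    # assuming LLAMA_STACK_API_V1 resolves to "v1".
    app.include_router(fastapi_routes.create_router(scoring_impl))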