llama-stack-api 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack_api/__init__.py +1100 -0
- llama_stack_api/admin/__init__.py +45 -0
- llama_stack_api/admin/api.py +72 -0
- llama_stack_api/admin/fastapi_routes.py +117 -0
- llama_stack_api/admin/models.py +113 -0
- llama_stack_api/agents/__init__.py +38 -0
- llama_stack_api/agents/api.py +52 -0
- llama_stack_api/agents/fastapi_routes.py +268 -0
- llama_stack_api/agents/models.py +181 -0
- llama_stack_api/batches/__init__.py +40 -0
- llama_stack_api/batches/api.py +53 -0
- llama_stack_api/batches/fastapi_routes.py +113 -0
- llama_stack_api/batches/models.py +78 -0
- llama_stack_api/benchmarks/__init__.py +43 -0
- llama_stack_api/benchmarks/api.py +39 -0
- llama_stack_api/benchmarks/fastapi_routes.py +109 -0
- llama_stack_api/benchmarks/models.py +109 -0
- llama_stack_api/common/__init__.py +5 -0
- llama_stack_api/common/content_types.py +101 -0
- llama_stack_api/common/errors.py +110 -0
- llama_stack_api/common/job_types.py +38 -0
- llama_stack_api/common/responses.py +77 -0
- llama_stack_api/common/training_types.py +47 -0
- llama_stack_api/common/type_system.py +146 -0
- llama_stack_api/connectors/__init__.py +38 -0
- llama_stack_api/connectors/api.py +50 -0
- llama_stack_api/connectors/fastapi_routes.py +103 -0
- llama_stack_api/connectors/models.py +103 -0
- llama_stack_api/conversations/__init__.py +61 -0
- llama_stack_api/conversations/api.py +44 -0
- llama_stack_api/conversations/fastapi_routes.py +177 -0
- llama_stack_api/conversations/models.py +245 -0
- llama_stack_api/datasetio/__init__.py +34 -0
- llama_stack_api/datasetio/api.py +42 -0
- llama_stack_api/datasetio/fastapi_routes.py +94 -0
- llama_stack_api/datasetio/models.py +48 -0
- llama_stack_api/datasets/__init__.py +61 -0
- llama_stack_api/datasets/api.py +35 -0
- llama_stack_api/datasets/fastapi_routes.py +104 -0
- llama_stack_api/datasets/models.py +152 -0
- llama_stack_api/datatypes.py +373 -0
- llama_stack_api/eval/__init__.py +55 -0
- llama_stack_api/eval/api.py +51 -0
- llama_stack_api/eval/compat.py +300 -0
- llama_stack_api/eval/fastapi_routes.py +126 -0
- llama_stack_api/eval/models.py +141 -0
- llama_stack_api/file_processors/__init__.py +27 -0
- llama_stack_api/file_processors/api.py +64 -0
- llama_stack_api/file_processors/fastapi_routes.py +78 -0
- llama_stack_api/file_processors/models.py +42 -0
- llama_stack_api/files/__init__.py +35 -0
- llama_stack_api/files/api.py +51 -0
- llama_stack_api/files/fastapi_routes.py +124 -0
- llama_stack_api/files/models.py +107 -0
- llama_stack_api/inference/__init__.py +207 -0
- llama_stack_api/inference/api.py +93 -0
- llama_stack_api/inference/fastapi_routes.py +243 -0
- llama_stack_api/inference/models.py +1035 -0
- llama_stack_api/inspect_api/__init__.py +37 -0
- llama_stack_api/inspect_api/api.py +25 -0
- llama_stack_api/inspect_api/fastapi_routes.py +76 -0
- llama_stack_api/inspect_api/models.py +28 -0
- llama_stack_api/internal/__init__.py +9 -0
- llama_stack_api/internal/kvstore.py +28 -0
- llama_stack_api/internal/sqlstore.py +81 -0
- llama_stack_api/models/__init__.py +47 -0
- llama_stack_api/models/api.py +38 -0
- llama_stack_api/models/fastapi_routes.py +104 -0
- llama_stack_api/models/models.py +157 -0
- llama_stack_api/openai_responses.py +1494 -0
- llama_stack_api/post_training/__init__.py +73 -0
- llama_stack_api/post_training/api.py +36 -0
- llama_stack_api/post_training/fastapi_routes.py +116 -0
- llama_stack_api/post_training/models.py +339 -0
- llama_stack_api/prompts/__init__.py +47 -0
- llama_stack_api/prompts/api.py +44 -0
- llama_stack_api/prompts/fastapi_routes.py +163 -0
- llama_stack_api/prompts/models.py +177 -0
- llama_stack_api/providers/__init__.py +33 -0
- llama_stack_api/providers/api.py +16 -0
- llama_stack_api/providers/fastapi_routes.py +57 -0
- llama_stack_api/providers/models.py +24 -0
- llama_stack_api/rag_tool.py +168 -0
- llama_stack_api/resource.py +36 -0
- llama_stack_api/router_utils.py +160 -0
- llama_stack_api/safety/__init__.py +37 -0
- llama_stack_api/safety/api.py +29 -0
- llama_stack_api/safety/datatypes.py +83 -0
- llama_stack_api/safety/fastapi_routes.py +55 -0
- llama_stack_api/safety/models.py +38 -0
- llama_stack_api/schema_utils.py +251 -0
- llama_stack_api/scoring/__init__.py +66 -0
- llama_stack_api/scoring/api.py +35 -0
- llama_stack_api/scoring/fastapi_routes.py +67 -0
- llama_stack_api/scoring/models.py +81 -0
- llama_stack_api/scoring_functions/__init__.py +50 -0
- llama_stack_api/scoring_functions/api.py +39 -0
- llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
- llama_stack_api/scoring_functions/models.py +214 -0
- llama_stack_api/shields/__init__.py +41 -0
- llama_stack_api/shields/api.py +39 -0
- llama_stack_api/shields/fastapi_routes.py +104 -0
- llama_stack_api/shields/models.py +74 -0
- llama_stack_api/tools.py +226 -0
- llama_stack_api/validators.py +46 -0
- llama_stack_api/vector_io/__init__.py +88 -0
- llama_stack_api/vector_io/api.py +234 -0
- llama_stack_api/vector_io/fastapi_routes.py +447 -0
- llama_stack_api/vector_io/models.py +663 -0
- llama_stack_api/vector_stores.py +53 -0
- llama_stack_api/version.py +9 -0
- {llama_stack_api-0.4.3.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
- llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
- llama_stack_api-0.5.0rc1.dist-info/top_level.txt +1 -0
- llama_stack_api-0.4.3.dist-info/RECORD +0 -4
- llama_stack_api-0.4.3.dist-info/top_level.txt +0 -1
- {llama_stack_api-0.4.3.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""Utilities for creating FastAPI routers with standard error responses.
|
|
8
|
+
|
|
9
|
+
This module provides standard error response definitions for FastAPI routers.
|
|
10
|
+
These responses use OpenAPI $ref references to component responses defined
|
|
11
|
+
in the OpenAPI specification.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import inspect
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
from typing import Annotated, Any, TypeVar
|
|
17
|
+
|
|
18
|
+
from fastapi import Path, Query
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
# OpenAPI extension key to mark routes that don't require authentication.
# Use this in FastAPI route decorators: @router.get("/health", openapi_extra={PUBLIC_ROUTE_KEY: True})
PUBLIC_ROUTE_KEY = "x-public"


# Error responses intended to be passed as `responses=` when building a router.
# Every entry is an OpenAPI `$ref` to a named component response instead of an
# inline definition; keys are HTTP status codes plus the OpenAPI "default" key.
standard_responses: dict[int | str, dict[str, Any]] = {
    400: {"$ref": "#/components/responses/BadRequest400"},
    429: {"$ref": "#/components/responses/TooManyRequests429"},
    500: {"$ref": "#/components/responses/InternalServerError500"},
    "default": {"$ref": "#/components/responses/DefaultError"},
}

# Type variable bounded to Pydantic models.
T = TypeVar("T", bound=BaseModel)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
|
|
37
|
+
"""Create a FastAPI dependency function from a Pydantic model for query parameters.
|
|
38
|
+
|
|
39
|
+
FastAPI does not natively support using Pydantic models as query parameters
|
|
40
|
+
without a dependency function. Using a dependency function typically leads to
|
|
41
|
+
duplication: field types, default values, and descriptions must be repeated in
|
|
42
|
+
`Query(...)` annotations even though they already exist in the Pydantic model.
|
|
43
|
+
|
|
44
|
+
This function automatically generates a dependency function that extracts query parameters
|
|
45
|
+
from the request and constructs an instance of the Pydantic model. The descriptions and
|
|
46
|
+
defaults are automatically extracted from the model's Field definitions, making the model
|
|
47
|
+
the single source of truth.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
model_class: The Pydantic model class to create a dependency for
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
A dependency function that can be used with FastAPI's Depends()
|
|
54
|
+
```
|
|
55
|
+
"""
|
|
56
|
+
# Build function signature dynamically from model fields
|
|
57
|
+
annotations: dict[str, Any] = {}
|
|
58
|
+
defaults: dict[str, Any] = {}
|
|
59
|
+
|
|
60
|
+
for field_name, field_info in model_class.model_fields.items():
|
|
61
|
+
# Extract description from Field
|
|
62
|
+
description = field_info.description
|
|
63
|
+
|
|
64
|
+
# Create Query annotation with description from model
|
|
65
|
+
query_annotation = Query(description=description) if description else Query()
|
|
66
|
+
|
|
67
|
+
# Create Annotated type with Query
|
|
68
|
+
field_type = field_info.annotation
|
|
69
|
+
annotations[field_name] = Annotated[field_type, query_annotation]
|
|
70
|
+
|
|
71
|
+
# Set default value from model
|
|
72
|
+
if field_info.default is not inspect.Parameter.empty:
|
|
73
|
+
defaults[field_name] = field_info.default
|
|
74
|
+
|
|
75
|
+
# Create the dependency function dynamically
|
|
76
|
+
def dependency_func(**kwargs: Any) -> T:
|
|
77
|
+
return model_class(**kwargs)
|
|
78
|
+
|
|
79
|
+
# Set function signature
|
|
80
|
+
sig_params = []
|
|
81
|
+
for field_name, field_type in annotations.items():
|
|
82
|
+
default = defaults.get(field_name, inspect.Parameter.empty)
|
|
83
|
+
param = inspect.Parameter(
|
|
84
|
+
field_name,
|
|
85
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
86
|
+
default=default,
|
|
87
|
+
annotation=field_type,
|
|
88
|
+
)
|
|
89
|
+
sig_params.append(param)
|
|
90
|
+
|
|
91
|
+
# These attributes are set dynamically at runtime. While mypy can't verify them statically,
|
|
92
|
+
# they are standard Python function attributes that exist on all callable objects at runtime.
|
|
93
|
+
# Setting them allows FastAPI to properly introspect the function signature for dependency injection.
|
|
94
|
+
dependency_func.__signature__ = inspect.Signature(sig_params) # type: ignore[attr-defined]
|
|
95
|
+
dependency_func.__annotations__ = annotations # type: ignore[attr-defined]
|
|
96
|
+
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" # type: ignore[attr-defined]
|
|
97
|
+
|
|
98
|
+
return dependency_func
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
|
|
102
|
+
"""Create a FastAPI dependency function from a Pydantic model for path parameters.
|
|
103
|
+
|
|
104
|
+
FastAPI requires path parameters to be explicitly annotated with `Path()`. When using
|
|
105
|
+
a Pydantic model that contains path parameters, you typically need a dependency function
|
|
106
|
+
that extracts the path parameter and constructs the model. This leads to duplication:
|
|
107
|
+
the parameter name, type, and description must be repeated in `Path(...)` annotations
|
|
108
|
+
even though they already exist in the Pydantic model.
|
|
109
|
+
|
|
110
|
+
This function automatically generates a dependency function that extracts path parameters
|
|
111
|
+
from the request and constructs an instance of the Pydantic model. The descriptions are
|
|
112
|
+
automatically extracted from the model's Field definitions, making the model the single
|
|
113
|
+
source of truth.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
model_class: The Pydantic model class to create a dependency for. The model should
|
|
117
|
+
have exactly one field that represents the path parameter.
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
A dependency function that can be used with FastAPI's Depends()
|
|
121
|
+
```
|
|
122
|
+
"""
|
|
123
|
+
# Get the single field from the model (path parameter models typically have one field)
|
|
124
|
+
if len(model_class.model_fields) != 1:
|
|
125
|
+
raise ValueError(
|
|
126
|
+
f"Path parameter model {model_class.__name__} must have exactly one field, "
|
|
127
|
+
f"but has {len(model_class.model_fields)} fields"
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
field_name, field_info = next(iter(model_class.model_fields.items()))
|
|
131
|
+
|
|
132
|
+
# Extract description from Field
|
|
133
|
+
description = field_info.description
|
|
134
|
+
|
|
135
|
+
# Create Path annotation with description from model
|
|
136
|
+
path_annotation = Path(description=description) if description else Path()
|
|
137
|
+
|
|
138
|
+
# Create Annotated type with Path
|
|
139
|
+
field_type = field_info.annotation
|
|
140
|
+
annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]}
|
|
141
|
+
|
|
142
|
+
# Create the dependency function dynamically
|
|
143
|
+
def dependency_func(**kwargs: Any) -> T:
|
|
144
|
+
return model_class(**kwargs)
|
|
145
|
+
|
|
146
|
+
# Set function signature
|
|
147
|
+
param = inspect.Parameter(
|
|
148
|
+
field_name,
|
|
149
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
150
|
+
annotation=annotations[field_name],
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# These attributes are set dynamically at runtime. While mypy can't verify them statically,
|
|
154
|
+
# they are standard Python function attributes that exist on all callable objects at runtime.
|
|
155
|
+
# Setting them allows FastAPI to properly introspect the function signature for dependency injection.
|
|
156
|
+
dependency_func.__signature__ = inspect.Signature([param]) # type: ignore[attr-defined]
|
|
157
|
+
dependency_func.__annotations__ = annotations # type: ignore[attr-defined]
|
|
158
|
+
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" # type: ignore[attr-defined]
|
|
159
|
+
|
|
160
|
+
return dependency_func
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""Safety API protocol and models.
|
|
8
|
+
|
|
9
|
+
This module contains the Safety protocol definition for content moderation and safety shields.
|
|
10
|
+
Pydantic models are defined in llama_stack_api.safety.models.
|
|
11
|
+
The FastAPI router is defined in llama_stack_api.safety.fastapi_routes.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from . import fastapi_routes
|
|
15
|
+
from .api import Safety
|
|
16
|
+
from .datatypes import (
|
|
17
|
+
ModerationObject,
|
|
18
|
+
ModerationObjectResults,
|
|
19
|
+
RunShieldResponse,
|
|
20
|
+
SafetyViolation,
|
|
21
|
+
ShieldStore,
|
|
22
|
+
ViolationLevel,
|
|
23
|
+
)
|
|
24
|
+
from .models import RunModerationRequest, RunShieldRequest
|
|
25
|
+
|
|
26
|
+
__all__ = [
|
|
27
|
+
"Safety",
|
|
28
|
+
"ShieldStore",
|
|
29
|
+
"ModerationObject",
|
|
30
|
+
"ModerationObjectResults",
|
|
31
|
+
"ViolationLevel",
|
|
32
|
+
"SafetyViolation",
|
|
33
|
+
"RunShieldResponse",
|
|
34
|
+
"RunShieldRequest",
|
|
35
|
+
"RunModerationRequest",
|
|
36
|
+
"fastapi_routes",
|
|
37
|
+
]
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Protocol, runtime_checkable
|
|
8
|
+
|
|
9
|
+
from llama_stack_api.safety.datatypes import ModerationObject, RunShieldResponse, ShieldStore
|
|
10
|
+
|
|
11
|
+
from .models import RunModerationRequest, RunShieldRequest
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@runtime_checkable
class Safety(Protocol):
    """Safety API for content moderation and safety shields.

    OpenAI-compatible Moderations API with additional shield capabilities.

    Structural protocol: implementations provide a `shield_store` attribute and
    the two coroutine methods below. It is decorated with `runtime_checkable`,
    so `isinstance` checks against it are permitted.
    """

    # Store used to look up shields.
    shield_store: ShieldStore

    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
        """Run a safety shield on messages.

        Args:
            request: The shield invocation request.

        Returns:
            The shield's response.
        """
        ...

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        """Classify if inputs are potentially harmful.

        Args:
            request: The moderation request.

        Returns:
            The moderation object describing the classification.
        """
        ...
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any, Protocol
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, Field
|
|
11
|
+
|
|
12
|
+
from llama_stack_api.schema_utils import json_schema_type
|
|
13
|
+
from llama_stack_api.shields import GetShieldRequest, Shield
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@json_schema_type
class ModerationObjectResults(BaseModel):
    """Result of moderating one input: an overall flag plus optional per-category detail."""

    flagged: bool = Field(
        ...,
        description="Whether any of the below categories are flagged",
    )
    categories: dict[str, bool] | None = Field(
        default=None,
        description="A dictionary of the categories, and whether they are flagged or not",
    )
    category_applied_input_types: dict[str, list[str]] | None = Field(
        default=None,
        description="A dictionary of the categories along with the input type(s) that the score applies to",
    )
    category_scores: dict[str, float] | None = Field(
        default=None,
        description="A dictionary of the categories along with their scores as predicted by model",
    )
    user_message: str | None = Field(
        default=None,
        description="A message to convey to the user about the moderation result",
    )
    # default_factory gives every instance its own dict.
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata about the moderation",
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@json_schema_type
class ModerationObject(BaseModel):
    """Top-level moderation response: request id, model name, and the list of results."""

    id: str = Field(
        ...,
        description="The unique identifier for the moderation request",
    )
    model: str = Field(
        ...,
        description="The model used to generate the moderation results",
    )
    results: list[ModerationObjectResults] = Field(
        ...,
        description="A list of moderation result objects",
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@json_schema_type
class ViolationLevel(Enum):
    """Severity level of a safety violation.

    Each member's value is a lowercase string ("info", "warn", "error").
    """

    INFO = "info"  # Informational level violation that does not require action
    WARN = "warn"  # Warning level violation that suggests caution but allows continuation
    ERROR = "error"  # Error level violation that requires blocking or intervention
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@json_schema_type
class SafetyViolation(BaseModel):
    """A single safety violation: its severity, an optional user-facing message, and metadata."""

    violation_level: ViolationLevel = Field(
        ...,
        description="Severity level of the violation",
    )
    user_message: str | None = Field(
        default=None,
        description="Message to convey to the user about the violation",
    )
    # default_factory gives every instance its own dict.
    metadata: dict[str, Any] = Field(
        description="Additional metadata including specific violation codes",
        default_factory=dict,
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@json_schema_type
class RunShieldResponse(BaseModel):
    """Outcome of a shield run; `violation` defaults to None."""

    violation: SafetyViolation | None = Field(
        default=None,
        description="Safety violation detected by the shield, if any",
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ShieldStore(Protocol):
    """Protocol for accessing shields.

    Any object with a matching async `get_shield` method satisfies it.
    """

    # Takes a GetShieldRequest and returns a Shield.
    async def get_shield(self, request: GetShieldRequest) -> Shield: ...
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
__all__ = [
|
|
77
|
+
"ModerationObjectResults",
|
|
78
|
+
"ModerationObject",
|
|
79
|
+
"ViolationLevel",
|
|
80
|
+
"SafetyViolation",
|
|
81
|
+
"RunShieldResponse",
|
|
82
|
+
"ShieldStore",
|
|
83
|
+
]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Annotated
|
|
8
|
+
|
|
9
|
+
from fastapi import APIRouter, Body
|
|
10
|
+
|
|
11
|
+
from llama_stack_api.router_utils import standard_responses
|
|
12
|
+
from llama_stack_api.version import LLAMA_STACK_API_V1
|
|
13
|
+
|
|
14
|
+
from .api import Safety
|
|
15
|
+
from .datatypes import ModerationObject, RunShieldResponse
|
|
16
|
+
from .models import RunModerationRequest, RunShieldRequest
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def create_router(impl: Safety) -> APIRouter:
    """Create a FastAPI router for the Safety API.

    Args:
        impl: The Safety implementation that both routes delegate to.

    Returns:
        An APIRouter with `POST /safety/run-shield` and `POST /moderations`,
        both mounted under the `/{LLAMA_STACK_API_V1}` prefix.
    """
    # standard_responses provides the router-wide error responses; each route
    # below only adds its own 200 description.
    router = APIRouter(
        prefix=f"/{LLAMA_STACK_API_V1}",
        tags=["Safety"],
        responses=standard_responses,
    )

    @router.post(
        "/safety/run-shield",
        response_model=RunShieldResponse,
        summary="Run Shield",
        description="Run a safety shield on messages to check for policy violations.",
        responses={
            200: {"description": "The shield response indicating any violations detected."},
        },
    )
    async def run_shield(
        request: Annotated[RunShieldRequest, Body(...)],
    ) -> RunShieldResponse:
        # The request model is read from the JSON body and forwarded unchanged.
        return await impl.run_shield(request)

    @router.post(
        "/moderations",
        response_model=ModerationObject,
        summary="Create Moderation",
        description="Classifies if text inputs are potentially harmful. OpenAI-compatible endpoint.",
        responses={
            200: {"description": "The moderation results for the input."},
        },
    )
    async def run_moderation(
        request: Annotated[RunModerationRequest, Body(...)],
    ) -> ModerationObject:
        # The request model is read from the JSON body and forwarded unchanged.
        return await impl.run_moderation(request)

    return router
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
from llama_stack_api.inference import OpenAIMessageParam
|
|
10
|
+
from llama_stack_api.schema_utils import json_schema_type
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@json_schema_type
class RunShieldRequest(BaseModel):
    """Body of a run-shield call: which shield to run and the messages to check."""

    # min_length=1 rejects an empty shield identifier at validation time.
    shield_id: str = Field(
        ...,
        min_length=1,
        description="The identifier of the shield to run",
    )
    messages: list[OpenAIMessageParam] = Field(
        ...,
        description="The messages to run the shield on",
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@json_schema_type
class RunModerationRequest(BaseModel):
    """Body of a moderation call: the text to classify and an optional model name."""

    input: str | list[str] = Field(
        ...,
        description="Input (or inputs) to classify. Can be a single string or an array of strings.",
    )
    model: str | None = Field(
        default=None,
        description="The content moderation model to use. If not specified, the default shield will be used.",
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
__all__ = [
|
|
36
|
+
"RunShieldRequest",
|
|
37
|
+
"RunModerationRequest",
|
|
38
|
+
]
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable, Iterable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any, Literal, TypeVar
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ExtraBodyField[T]:
    """
    Marker annotation for parameters that arrive via extra_body in the client SDK.

    These parameters:
    - Will NOT appear in the generated client SDK method signature
    - WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params
    - MUST be passed via the extra_body parameter in client SDK calls
    - WILL be available in server-side method signature with proper typing

    Example:
    ```python
    async def create_openai_response(
        self,
        input: str,
        model: str,
        shields: Annotated[
            list[str] | None, ExtraBodyField("List of shields to apply")
        ] = None,
    ) -> ResponseObject:
        # shields is available here with proper typing
        if shields:
            print(f"Using shields: {shields}")
    ```

    Client usage:
    ```python
    client.responses.create(
        input="hello", model="llama-3", extra_body={"shields": ["shield-1"]}
    )
    ```
    """

    def __init__(self, description: str | None = None) -> None:
        """Store the optional human-readable description of the parameter.

        Args:
            description: Description text for the parameter, or None.
        """
        self.description = description
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# The three registration paths a schema entry can come from; each value is the
# `source` written by the corresponding registration function in this module.
SchemaSource = Literal["json_schema_type", "registered_schema", "dynamic_schema"]


@dataclass(frozen=True)
class SchemaInfo:
    """Metadata describing a schema entry exposed to OpenAPI generation.

    Instances are immutable (frozen dataclass).
    """

    # Name under which the schema is registered.
    name: str
    # The registered type object itself.
    type: Any
    # Which registration path produced this entry.
    source: SchemaSource
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Classes decorated with @json_schema_type, keyed by the class object.
_json_schema_types: dict[type, SchemaInfo] = {}


def json_schema_type(cls):
    """
    Class decorator that flags a Pydantic model as a top-level OpenAPI component.

    Decorated models are registered as named components in the OpenAPI schema;
    models without the decorator are inlined. This keeps reusable types as
    components while avoiding indirection for simple one-off types.

    The class is tagged with `_llama_stack_schema_type` and
    `_llama_stack_schema_name`, recorded in the module registry (an existing
    entry for the same class is kept), and returned unchanged.
    """
    cls._llama_stack_schema_type = True
    component_name = getattr(cls, "__name__", f"Anonymous_{id(cls)}")
    cls._llama_stack_schema_name = component_name

    # First registration wins; decorating the same class again is a no-op here.
    if cls not in _json_schema_types:
        _json_schema_types[cls] = SchemaInfo(name=component_name, type=cls, source="json_schema_type")
    return cls
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Global registries for schemas discoverable by the generator
_registered_schemas: dict[Any, SchemaInfo] = {}
_dynamic_schema_types: dict[type, SchemaInfo] = {}


def register_schema(schema_type, name: str | None = None):
    """
    Record a type so it is emitted as a top-level OpenAPI component.

    Mirrors strong_typing's register_schema. Intended for union types and other
    complex types that should appear as named components in the OpenAPI schema.

    Args:
        schema_type: The type to register (e.g., union types, Annotated types)
        name: Optional name for the schema in the OpenAPI spec. When omitted, the
            type's __name__ is used, falling back to a generated name.

    Returns:
        The `schema_type` that was passed in, unchanged.
    """
    schema_name = name if name is not None else getattr(schema_type, "__name__", f"Anonymous_{id(schema_type)}")

    # Union types don't allow setting attributes, so the metadata lives in a
    # module-level registry. Registering the same type again overwrites the entry.
    entry = SchemaInfo(name=schema_name, type=schema_type, source="registered_schema")
    _registered_schemas[schema_type] = entry
    return schema_type
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_registered_schema_info(schema_type: Any) -> SchemaInfo | None:
    """Look up the metadata stored for an explicitly registered schema type.

    Returns:
        The matching SchemaInfo, or None when the type has not been registered.
    """
    info = _registered_schemas.get(schema_type)
    return info
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def iter_registered_schema_types() -> Iterable[SchemaInfo]:
    """Return a snapshot tuple of every explicitly registered schema entry."""
    entries = _registered_schemas.values()
    return tuple(entries)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def iter_json_schema_types() -> Iterable[type]:
    """Return a snapshot tuple of the types decorated with @json_schema_type."""
    decorated = [entry.type for entry in _json_schema_types.values()]
    return tuple(decorated)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def iter_dynamic_schema_types() -> Iterable[type]:
    """Return a snapshot tuple of the dynamically registered model types."""
    dynamic = [entry.type for entry in _dynamic_schema_types.values()]
    return tuple(dynamic)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def register_dynamic_schema_type(schema_type: type, name: str | None = None) -> type:
    """Record a model generated at runtime so it is included in the schema.

    Args:
        schema_type: The dynamically created type to register.
        name: Schema name to use. When None, the type's __name__ is used,
            falling back to a generated name.

    Returns:
        The `schema_type` that was passed in, unchanged.
    """
    if name is not None:
        schema_name = name
    else:
        schema_name = getattr(schema_type, "__name__", f"Anonymous_{id(schema_type)}")

    # Registering the same type again replaces its previous entry.
    entry = SchemaInfo(name=schema_name, type=schema_type, source="dynamic_schema")
    _dynamic_schema_types[schema_type] = entry
    return schema_type
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def clear_dynamic_schema_types() -> None:
    """Drop every dynamic schema registration.

    The registry dict is emptied in place rather than replaced.
    """
    registry = _dynamic_schema_types
    registry.clear()
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@dataclass
class WebMethod:
    """Metadata record that the `webmethod` decorator attaches to an endpoint function."""

    # NOTE(review): presumably the API level/version of the route — confirm.
    level: str | None = None
    # The URL path pattern associated with this operation.
    route: str | None = None
    # True if the operation can be invoked without prior authentication.
    public: bool = False
    # Sample requests that the operation might take (objects, not JSON).
    request_examples: list[Any] | None = None
    # Sample responses that the operation might produce (objects, not JSON).
    response_examples: list[Any] | None = None
    # NOTE(review): presumably the HTTP method name — confirm.
    method: str | None = None
    raw_bytes_request_body: bool | None = False
    # A descriptive name of the corresponding span created by tracing
    descriptive_name: str | None = None
    deprecated: bool | None = False
    # Whether this endpoint requires authentication (default True).
    require_authentication: bool | None = True
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
CallableT = TypeVar("CallableT", bound=Callable[..., Any])


def webmethod(
    route: str | None = None,
    method: str | None = None,
    level: str | None = None,
    public: bool | None = False,
    request_examples: list[Any] | None = None,
    response_examples: list[Any] | None = None,
    raw_bytes_request_body: bool | None = False,
    descriptive_name: str | None = None,
    deprecated: bool | None = False,
    require_authentication: bool | None = True,
) -> Callable[[CallableT], CallableT]:
    """
    Decorator factory that attaches endpoint metadata to an operation function.

    The decorated function is returned as-is (not wrapped). Each application
    appends a WebMethod to the function's `__webmethods__` list, and
    `__webmethod__` is set to the one created by that application.

    :param route: The URL path pattern associated with this operation which path parameters are substituted into.
    :param public: True if the operation can be invoked without prior authentication. None is stored as False.
    :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
    :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
    :param require_authentication: Whether this endpoint requires authentication (default True). None is stored as True.

    The remaining parameters (method, level, raw_bytes_request_body,
    descriptive_name, deprecated) are copied onto the WebMethod unchanged.
    """
    # Normalise the two flags that accept None before building the metadata.
    is_public = public or False
    auth_required = True if require_authentication is None else require_authentication

    def wrap(func: CallableT) -> CallableT:
        metadata = WebMethod(
            route=route,
            method=method,
            level=level,
            public=is_public,
            request_examples=request_examples,
            response_examples=response_examples,
            raw_bytes_request_body=raw_bytes_request_body,
            descriptive_name=descriptive_name,
            deprecated=deprecated,
            require_authentication=auth_required,
        )

        # A function may be decorated several times; collect every entry.
        try:
            func.__webmethods__  # type: ignore
        except AttributeError:
            func.__webmethods__ = []  # type: ignore
        func.__webmethods__.append(metadata)  # type: ignore

        # Single-entry attribute kept for backwards compatibility.
        func.__webmethod__ = metadata  # type: ignore
        return func

    return wrap
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def remove_null_from_anyof(schema: dict, *, add_nullable: bool = False) -> None:
    """Flatten a nullable schema in place, optionally marking it `nullable`.

    Two input shapes are handled:

    * `anyOf: [{type: X, ...}, {type: null}]` is replaced by the single non-null
      alternative merged into `schema` itself. This happens only when exactly
      one non-null alternative exists; otherwise `anyOf` is left untouched.
    * `type: [X, 'null']` has 'null' removed from the list, and a list left with
      one entry is collapsed to that plain string.

    Args:
        schema: The JSON schema dict to modify in-place
        add_nullable: If True, adds 'nullable: true' when null was present.
            Use True for OpenAPI 3.0 compatibility with OpenAI's spec.
    """
    if "anyOf" in schema:
        alternatives = schema["anyOf"]
        kept = [alt for alt in alternatives if alt.get("type") != "null"]
        if len(kept) != 1:
            # Zero or several non-null alternatives: nothing to flatten.
            return

        null_was_present = len(kept) < len(alternatives)
        del schema["anyOf"]
        # Merge the surviving alternative's keys into the outer schema.
        schema.update(kept[0])
        if add_nullable and null_was_present:
            schema["nullable"] = True
        return

    # OpenAPI 3.1 list form: type: ['string', 'null']
    type_value = schema.get("type")
    if not isinstance(type_value, list) or "null" not in type_value:
        return

    # The list is edited in place, so the original list object is mutated.
    type_value.remove("null")
    if len(type_value) == 1:
        schema["type"] = type_value[0]
    if add_nullable:
        schema["nullable"] = True
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def nullable_openai_style(schema: dict) -> None:
    """Shorthand for remove_null_from_anyof with add_nullable=True.

    Use this for fields that need OpenAPI 3.0 nullable style to match OpenAI's spec.

    Args:
        schema: The JSON schema dict to modify in-place.
    """
    remove_null_from_anyof(schema, add_nullable=True)
|