llama-stack-api 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack_api/__init__.py +945 -0
- llama_stack_api/admin/__init__.py +45 -0
- llama_stack_api/admin/api.py +72 -0
- llama_stack_api/admin/fastapi_routes.py +117 -0
- llama_stack_api/admin/models.py +113 -0
- llama_stack_api/agents.py +173 -0
- llama_stack_api/batches/__init__.py +40 -0
- llama_stack_api/batches/api.py +53 -0
- llama_stack_api/batches/fastapi_routes.py +113 -0
- llama_stack_api/batches/models.py +78 -0
- llama_stack_api/benchmarks/__init__.py +43 -0
- llama_stack_api/benchmarks/api.py +39 -0
- llama_stack_api/benchmarks/fastapi_routes.py +109 -0
- llama_stack_api/benchmarks/models.py +109 -0
- llama_stack_api/common/__init__.py +5 -0
- llama_stack_api/common/content_types.py +101 -0
- llama_stack_api/common/errors.py +95 -0
- llama_stack_api/common/job_types.py +38 -0
- llama_stack_api/common/responses.py +77 -0
- llama_stack_api/common/training_types.py +47 -0
- llama_stack_api/common/type_system.py +146 -0
- llama_stack_api/connectors.py +146 -0
- llama_stack_api/conversations.py +270 -0
- llama_stack_api/datasetio.py +55 -0
- llama_stack_api/datasets/__init__.py +61 -0
- llama_stack_api/datasets/api.py +35 -0
- llama_stack_api/datasets/fastapi_routes.py +104 -0
- llama_stack_api/datasets/models.py +152 -0
- llama_stack_api/datatypes.py +373 -0
- llama_stack_api/eval.py +137 -0
- llama_stack_api/file_processors/__init__.py +27 -0
- llama_stack_api/file_processors/api.py +64 -0
- llama_stack_api/file_processors/fastapi_routes.py +78 -0
- llama_stack_api/file_processors/models.py +42 -0
- llama_stack_api/files/__init__.py +35 -0
- llama_stack_api/files/api.py +51 -0
- llama_stack_api/files/fastapi_routes.py +124 -0
- llama_stack_api/files/models.py +107 -0
- llama_stack_api/inference.py +1169 -0
- llama_stack_api/inspect_api/__init__.py +37 -0
- llama_stack_api/inspect_api/api.py +25 -0
- llama_stack_api/inspect_api/fastapi_routes.py +76 -0
- llama_stack_api/inspect_api/models.py +28 -0
- llama_stack_api/internal/__init__.py +9 -0
- llama_stack_api/internal/kvstore.py +28 -0
- llama_stack_api/internal/sqlstore.py +81 -0
- llama_stack_api/models.py +171 -0
- llama_stack_api/openai_responses.py +1468 -0
- llama_stack_api/post_training.py +370 -0
- llama_stack_api/prompts.py +203 -0
- llama_stack_api/providers/__init__.py +33 -0
- llama_stack_api/providers/api.py +16 -0
- llama_stack_api/providers/fastapi_routes.py +57 -0
- llama_stack_api/providers/models.py +24 -0
- llama_stack_api/rag_tool.py +168 -0
- llama_stack_api/resource.py +37 -0
- llama_stack_api/router_utils.py +160 -0
- llama_stack_api/safety.py +132 -0
- llama_stack_api/schema_utils.py +208 -0
- llama_stack_api/scoring.py +93 -0
- llama_stack_api/scoring_functions.py +211 -0
- llama_stack_api/shields.py +93 -0
- llama_stack_api/tools.py +226 -0
- llama_stack_api/vector_io.py +941 -0
- llama_stack_api/vector_stores.py +53 -0
- llama_stack_api/version.py +9 -0
- {llama_stack_api-0.4.2.dist-info → llama_stack_api-0.4.4.dist-info}/METADATA +1 -1
- llama_stack_api-0.4.4.dist-info/RECORD +70 -0
- {llama_stack_api-0.4.2.dist-info → llama_stack_api-0.4.4.dist-info}/WHEEL +1 -1
- llama_stack_api-0.4.4.dist-info/top_level.txt +1 -0
- llama_stack_api-0.4.2.dist-info/RECORD +0 -4
- llama_stack_api-0.4.2.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from collections.abc import Callable, Iterable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any, Literal, TypeVar
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ExtraBodyField[T]:
|
|
13
|
+
"""
|
|
14
|
+
Marker annotation for parameters that arrive via extra_body in the client SDK.
|
|
15
|
+
|
|
16
|
+
These parameters:
|
|
17
|
+
- Will NOT appear in the generated client SDK method signature
|
|
18
|
+
- WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params
|
|
19
|
+
- MUST be passed via the extra_body parameter in client SDK calls
|
|
20
|
+
- WILL be available in server-side method signature with proper typing
|
|
21
|
+
|
|
22
|
+
Example:
|
|
23
|
+
```python
|
|
24
|
+
async def create_openai_response(
|
|
25
|
+
self,
|
|
26
|
+
input: str,
|
|
27
|
+
model: str,
|
|
28
|
+
shields: Annotated[
|
|
29
|
+
list[str] | None, ExtraBodyField("List of shields to apply")
|
|
30
|
+
] = None,
|
|
31
|
+
) -> ResponseObject:
|
|
32
|
+
# shields is available here with proper typing
|
|
33
|
+
if shields:
|
|
34
|
+
print(f"Using shields: {shields}")
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
Client usage:
|
|
38
|
+
```python
|
|
39
|
+
client.responses.create(
|
|
40
|
+
input="hello", model="llama-3", extra_body={"shields": ["shield-1"]}
|
|
41
|
+
)
|
|
42
|
+
```
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def __init__(self, description: str | None = None):
|
|
46
|
+
self.description = description
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# The registration path that produced a schema entry.
SchemaSource = Literal["json_schema_type", "registered_schema", "dynamic_schema"]


@dataclass(frozen=True)
class SchemaInfo:
    """Immutable record describing one schema entry made available to OpenAPI generation."""

    # Name the schema is registered under.
    name: str
    # The registered type object itself.
    type: Any
    # Which registration path created this entry.
    source: SchemaSource
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Classes decorated with @json_schema_type, keyed by the class itself.
_json_schema_types: dict[type, SchemaInfo] = {}


def json_schema_type(cls):
    """
    Class decorator that flags a Pydantic model as a top-level OpenAPI component.

    Decorated models are registered as top-level components in the OpenAPI
    schema; models without the decorator are inlined. This keeps simple one-off
    types free of unnecessary indirection while complex, reusable types stay
    available as named components.

    The class is tagged with ``_llama_stack_schema_type`` and
    ``_llama_stack_schema_name`` and recorded in the module-level registry. A
    class that is already recorded keeps its existing registry entry.

    :param cls: The class to mark.
    :returns: The same class, so the function can be applied as a decorator.
    """
    schema_name = getattr(cls, "__name__", f"Anonymous_{id(cls)}")
    cls._llama_stack_schema_type = True
    cls._llama_stack_schema_name = schema_name
    # Only the first registration of a class is kept.
    if cls not in _json_schema_types:
        _json_schema_types[cls] = SchemaInfo(name=schema_name, type=cls, source="json_schema_type")
    return cls
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Global registries for schemas discoverable by the generator
_registered_schemas: dict[Any, SchemaInfo] = {}
_dynamic_schema_types: dict[type, SchemaInfo] = {}


def register_schema(schema_type, name: str | None = None):
    """
    Record a schema type so it appears as a top-level OpenAPI component.

    This mirrors strong_typing's register_schema function and is meant for
    union types and other complex types that should be emitted as top-level
    components in the OpenAPI schema. Registering the same type again replaces
    its earlier entry.

    Args:
        schema_type: The type to register (e.g., union types, Annotated types)
        name: Optional name for the schema in the OpenAPI spec. When omitted,
              the type's __name__ is used, or a generated name if it has none.

    Returns:
        ``schema_type`` unchanged.
    """
    resolved_name = name
    if resolved_name is None:
        resolved_name = getattr(schema_type, "__name__", f"Anonymous_{id(schema_type)}")

    # Union types don't allow setting attributes, so the registration data is
    # kept in the module-level registry instead of on the type.
    entry = SchemaInfo(name=resolved_name, type=schema_type, source="registered_schema")
    _registered_schemas[schema_type] = entry

    return schema_type
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def get_registered_schema_info(schema_type: Any) -> SchemaInfo | None:
    """Look up the explicit registration metadata for ``schema_type``.

    :param schema_type: The type to look up in the explicit-registration registry.
    :returns: The stored ``SchemaInfo``, or ``None`` when the type is not registered.
    """
    info = _registered_schemas.get(schema_type)
    return info
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def iter_registered_schema_types() -> Iterable[SchemaInfo]:
    """Return a snapshot of every explicitly registered schema entry.

    The result is a tuple of ``SchemaInfo`` records (not bare types), built at
    call time, so later registrations do not affect it.
    """
    snapshot = tuple(_registered_schemas.values())
    return snapshot
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def iter_json_schema_types() -> Iterable[type]:
    """Return a snapshot tuple of the classes decorated with @json_schema_type."""
    decorated = [entry.type for entry in _json_schema_types.values()]
    return tuple(decorated)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def iter_dynamic_schema_types() -> Iterable[type]:
    """Return a snapshot tuple of the dynamic models registered at generation time."""
    dynamic = [entry.type for entry in _dynamic_schema_types.values()]
    return tuple(dynamic)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def register_dynamic_schema_type(schema_type: type, name: str | None = None) -> type:
    """Record a model generated at runtime so it is included in the schema.

    :param schema_type: The dynamically created type to register.
    :param name: Optional schema name. When omitted, the type's __name__ is used,
        or a generated name if it has none.
    :returns: ``schema_type`` unchanged.
    """
    if name is None:
        schema_name = getattr(schema_type, "__name__", f"Anonymous_{id(schema_type)}")
    else:
        schema_name = name
    # Registering the same type again replaces its earlier entry.
    entry = SchemaInfo(name=schema_name, type=schema_type, source="dynamic_schema")
    _dynamic_schema_types[schema_type] = entry
    return schema_type
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def clear_dynamic_schema_types() -> None:
    """Remove every dynamic schema registration.

    The registry dict is emptied in place, so existing references to it stay valid.
    """
    _dynamic_schema_types.clear()
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
@dataclass
|
|
142
|
+
class WebMethod:
|
|
143
|
+
level: str | None = None
|
|
144
|
+
route: str | None = None
|
|
145
|
+
public: bool = False
|
|
146
|
+
request_examples: list[Any] | None = None
|
|
147
|
+
response_examples: list[Any] | None = None
|
|
148
|
+
method: str | None = None
|
|
149
|
+
raw_bytes_request_body: bool | None = False
|
|
150
|
+
# A descriptive name of the corresponding span created by tracing
|
|
151
|
+
descriptive_name: str | None = None
|
|
152
|
+
required_scope: str | None = None
|
|
153
|
+
deprecated: bool | None = False
|
|
154
|
+
require_authentication: bool | None = True
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Type variable that lets the decorator return the decorated callable with its type intact.
CallableT = TypeVar("CallableT", bound=Callable[..., Any])


def webmethod(
    route: str | None = None,
    method: str | None = None,
    level: str | None = None,
    public: bool | None = False,
    request_examples: list[Any] | None = None,
    response_examples: list[Any] | None = None,
    raw_bytes_request_body: bool | None = False,
    descriptive_name: str | None = None,
    required_scope: str | None = None,
    deprecated: bool | None = False,
    require_authentication: bool | None = True,
) -> Callable[[CallableT], CallableT]:
    """
    Decorator factory that attaches endpoint metadata to an operation function.

    Each application appends a ``WebMethod`` to the function's ``__webmethods__``
    list, so one function can carry several routes. ``__webmethod__`` always
    points at the most recently applied one. The function itself is returned
    unwrapped.

    :param route: The URL path pattern associated with this operation which path parameters are substituted into.
    :param method: The HTTP method of the route (e.g. "GET", "POST", "DELETE").
    :param level: The API level of the route (e.g. LLAMA_STACK_API_V1).
    :param public: True if the operation can be invoked without prior authentication. None is stored as False.
    :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
    :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
    :param raw_bytes_request_body: Stored on the metadata unchanged; presumably marks a raw-bytes request body (confirm with the consumer).
    :param descriptive_name: A descriptive name of the corresponding span created by tracing.
    :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
    :param deprecated: Flag recording that the operation is deprecated.
    :param require_authentication: Whether this endpoint requires authentication (default True). None is stored as True.
    :returns: A decorator that records the metadata and returns the function it was given.
    """

    def wrap(func: CallableT) -> CallableT:
        metadata = WebMethod(
            route=route,
            method=method,
            level=level,
            public=public or False,
            request_examples=request_examples,
            response_examples=response_examples,
            raw_bytes_request_body=raw_bytes_request_body,
            descriptive_name=descriptive_name,
            required_scope=required_scope,
            deprecated=deprecated,
            require_authentication=True if require_authentication is None else require_authentication,
        )

        # Accumulate all webmethods in a list so the decorator can be stacked.
        if hasattr(func, "__webmethods__"):
            func.__webmethods__.append(metadata)  # type: ignore
        else:
            func.__webmethods__ = [metadata]  # type: ignore

        # The most recently applied one stays on __webmethod__ for backwards compatibility.
        func.__webmethod__ = metadata  # type: ignore
        return func

    return wrap
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Any, Protocol, runtime_checkable
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
|
12
|
+
from llama_stack_api.scoring_functions import ScoringFn, ScoringFnParams
|
|
13
|
+
from llama_stack_api.version import LLAMA_STACK_API_V1
|
|
14
|
+
|
|
15
|
+
# mapping of metric to value
ScoringResultRow = dict[str, Any]


@json_schema_type
class ScoringResult(BaseModel):
    """
    The result of one scoring function: per-row scores plus aggregated metrics.

    :param score_rows: The scoring result for each row. Each row is a map of column name to value.
    :param aggregated_results: Map of metric name to aggregated value
    """

    # One entry per scored row.
    score_rows: list[ScoringResultRow]
    # aggregated metrics to value
    aggregated_results: dict[str, Any]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Return type of Scoring.score_batch.
@json_schema_type
class ScoreBatchResponse(BaseModel):
    """Response from batch scoring operations on datasets.

    :param dataset_id: (Optional) The identifier of the dataset that was scored
    :param results: A map of scoring function name to ScoringResult
    """

    # Left as None when no dataset identifier is reported.
    dataset_id: str | None = None
    # Keyed by scoring function name.
    results: dict[str, ScoringResult]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Return type of Scoring.score.
@json_schema_type
class ScoreResponse(BaseModel):
    """
    The response from scoring.

    :param results: A map of scoring function name to ScoringResult.
    """

    # each key in the dict is a scoring function name
    results: dict[str, ScoringResult]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Structural interface for anything that can resolve a scoring function by its ID.
class ScoringFunctionStore(Protocol):
    # Return the ScoringFn registered under scoring_fn_id.
    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Protocol for the scoring API. Rows are scored against a set of scoring
# functions, taken either from a dataset (score_batch) or passed inline (score).
@runtime_checkable
class Scoring(Protocol):
    # Provides get_scoring_function() lookups by scoring function ID.
    scoring_function_store: ScoringFunctionStore

    @webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Score a batch of rows.

        :param dataset_id: The ID of the dataset to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :param save_results_dataset: Whether to save the results to a dataset.
        :returns: A ScoreBatchResponse.
        """
        ...

    @webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None],
    ) -> ScoreResponse:
        """Score a list of rows.

        :param input_rows: The rows to score.
        :param scoring_functions: The scoring functions to use for the scoring.
        :returns: A ScoreResponse object containing rows and aggregated results.
        """
        ...
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# enum.StrEnum requires Python 3.11+; Python 3.10 support has already been dropped.
|
|
8
|
+
from enum import StrEnum
|
|
9
|
+
from typing import (
|
|
10
|
+
Annotated,
|
|
11
|
+
Any,
|
|
12
|
+
Literal,
|
|
13
|
+
Protocol,
|
|
14
|
+
runtime_checkable,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from pydantic import BaseModel, Field
|
|
18
|
+
|
|
19
|
+
from llama_stack_api.common.type_system import ParamType
|
|
20
|
+
from llama_stack_api.resource import Resource, ResourceType
|
|
21
|
+
from llama_stack_api.schema_utils import json_schema_type, register_schema, webmethod
|
|
22
|
+
from llama_stack_api.version import LLAMA_STACK_API_V1
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
|
|
26
|
+
# with standard metrics so they can be rolled up?
|
|
27
|
+
# Discriminator values for the ScoringFnParams union; each member's value equals its name.
@json_schema_type
class ScoringFnParamsType(StrEnum):
    """Types of scoring function parameter configurations.
    :cvar llm_as_judge: Use an LLM model to evaluate and score responses
    :cvar regex_parser: Use regex patterns to extract and score specific parts of responses
    :cvar basic: Basic scoring with simple aggregation functions
    """

    llm_as_judge = "llm_as_judge"
    regex_parser = "regex_parser"
    basic = "basic"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Names of the aggregations a scoring function can apply; each member's value equals its name.
@json_schema_type
class AggregationFunctionType(StrEnum):
    """Types of aggregation functions for scoring results.
    :cvar average: Calculate the arithmetic mean of scores
    :cvar weighted_average: Calculate a weighted average of scores
    :cvar median: Calculate the median value of scores
    :cvar categorical_count: Count occurrences of categorical values
    :cvar accuracy: Calculate accuracy as the proportion of correct answers
    """

    average = "average"
    weighted_average = "weighted_average"
    median = "median"
    categorical_count = "categorical_count"
    accuracy = "accuracy"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Member of the ScoringFnParams union selected by type == "llm_as_judge".
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
    """Parameters for LLM-as-judge scoring function configuration.
    :param type: The type of scoring function parameters, always llm_as_judge
    :param judge_model: Identifier of the LLM model to use as a judge for scoring
    :param prompt_template: (Optional) Custom prompt template for the judge model
    :param judge_score_regexes: Regexes to extract the answer from generated response
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    # Discriminator field; fixed to llm_as_judge.
    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
    # Required: there is no default judge model.
    judge_model: str
    prompt_template: str | None = None
    # `list` is used as the factory (instead of `lambda: []`) so each instance
    # gets its own empty list, matching BasicScoringFnParams.
    judge_score_regexes: list[str] = Field(
        description="Regexes to extract the answer from generated response",
        default_factory=list,
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=list,
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Member of the ScoringFnParams union selected by type == "regex_parser".
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
    """Parameters for regex parser scoring function configuration.
    :param type: The type of scoring function parameters, always regex_parser
    :param parsing_regexes: Regex to extract the answer from generated response
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    # Discriminator field; fixed to regex_parser.
    type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
    # `list` is used as the factory (instead of `lambda: []`) so each instance
    # gets its own empty list, matching BasicScoringFnParams.
    parsing_regexes: list[str] = Field(
        description="Regex to extract the answer from generated response",
        default_factory=list,
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=list,
    )
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Member of the ScoringFnParams union selected by type == "basic".
@json_schema_type
class BasicScoringFnParams(BaseModel):
    """Parameters for basic scoring function configuration.
    :param type: The type of scoring function parameters, always basic
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    # Discriminator field; fixed to basic.
    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
    # Defaults to a fresh empty list per instance.
    aggregation_functions: list[AggregationFunctionType] = Field(
        default_factory=list,
        description="Aggregation functions to apply to the scores of each row",
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# Tagged union of the three parameter models; the "type" field selects which
# member a payload is validated against.
ScoringFnParams = Annotated[
    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
    Field(discriminator="type"),
]
# Register the union under an explicit name, since an Annotated union has no
# class of its own to decorate with @json_schema_type.
register_schema(ScoringFnParams, name="ScoringFnParams")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# Fields shared by the ScoringFn resource and the ScoringFnInput payload.
class CommonScoringFnFields(BaseModel):
    description: str | None = None
    # Defaults to a fresh empty dict per instance.
    metadata: dict[str, Any] = Field(
        description="Any additional metadata for this definition",
        default_factory=dict,
    )
    # Required: no default is provided.
    return_type: ParamType = Field(
        description="The return type of the deterministic function",
    )
    params: ScoringFnParams | None = Field(
        default=None,
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
    )
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
    """A scoring function resource for evaluating model outputs.
    :param type: The resource type, always scoring_function
    """

    # Fixed to the scoring_function resource type.
    type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function

    @property
    def scoring_fn_id(self) -> str:
        # Read-only alias for the resource's `identifier`.
        return self.identifier

    @property
    def provider_scoring_fn_id(self) -> str | None:
        # Read-only alias for the resource's `provider_resource_id`.
        return self.provider_resource_id
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Input form of a scoring function: the shared fields plus explicit identifiers.
class ScoringFnInput(CommonScoringFnFields, BaseModel):
    # Required identifier of the scoring function.
    scoring_fn_id: str
    # Both provider fields are optional and default to None.
    provider_id: str | None = None
    provider_scoring_fn_id: str | None = None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Return type of ScoringFunctions.list_scoring_functions.
@json_schema_type
class ListScoringFunctionsResponse(BaseModel):
    # The scoring function resources being listed.
    data: list[ScoringFn]
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
# Protocol for the scoring-functions API: list, fetch, register and unregister.
# The register and unregister routes are flagged deprecated.
@runtime_checkable
class ScoringFunctions(Protocol):
    @webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        """List all scoring functions.

        :returns: A ListScoringFunctionsResponse.
        """
        ...

    # scoring_fn_id is positional-only here; the {scoring_fn_id:path} converter
    # in the route presumably lets the ID contain slashes (confirm with the router).
    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
        """Get a scoring function by its ID.

        :param scoring_fn_id: The ID of the scoring function to get.
        :returns: A ScoringFn.
        """
        ...

    @webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        """Register a scoring function.

        :param scoring_fn_id: The ID of the scoring function to register.
        :param description: The description of the scoring function.
        :param return_type: The return type of the scoring function.
        :param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
        :param provider_id: The ID of the provider to use for the scoring function.
        :param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
        """
        ...

    @webmethod(
        route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
    )
    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
        """Unregister a scoring function.

        :param scoring_fn_id: The ID of the scoring function to unregister.
        """
        ...
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
+
# the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from typing import Any, Literal, Protocol, runtime_checkable
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from llama_stack_api.resource import Resource, ResourceType
|
|
12
|
+
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
|
13
|
+
from llama_stack_api.version import LLAMA_STACK_API_V1
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Fields shared by the Shield resource and the ShieldInput payload.
class CommonShieldFields(BaseModel):
    # Optional free-form parameters; None when not supplied.
    params: dict[str, Any] | None = None
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@json_schema_type
class Shield(CommonShieldFields, Resource):
    """A safety shield resource that can be used to check content.

    :param params: (Optional) Configuration parameters for the shield
    :param type: The resource type, always shield
    """

    # Fixed to the shield resource type.
    type: Literal[ResourceType.shield] = ResourceType.shield

    @property
    def shield_id(self) -> str:
        # Read-only alias for the resource's `identifier`.
        return self.identifier

    @property
    def provider_shield_id(self) -> str | None:
        # Read-only alias for the resource's `provider_resource_id`.
        return self.provider_resource_id
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Input form of a shield: the shared fields plus explicit identifiers.
class ShieldInput(CommonShieldFields):
    # Required identifier of the shield.
    shield_id: str
    # Both provider fields are optional and default to None.
    provider_id: str | None = None
    provider_shield_id: str | None = None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Return type of Shields.list_shields.
@json_schema_type
class ListShieldsResponse(BaseModel):
    # The shield resources being listed.
    data: list[Shield]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Protocol for the shields API: list, fetch, register and unregister.
# The register and unregister routes are flagged deprecated.
@runtime_checkable
class Shields(Protocol):
    @webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
    async def list_shields(self) -> ListShieldsResponse:
        """List all shields.

        :returns: A ListShieldsResponse.
        """
        ...

    # The {identifier:path} converter in the route presumably lets the
    # identifier contain slashes (confirm with the router).
    @webmethod(route="/shields/{identifier:path}", method="GET", level=LLAMA_STACK_API_V1)
    async def get_shield(self, identifier: str) -> Shield:
        """Get a shield by its identifier.

        :param identifier: The identifier of the shield to get.
        :returns: A Shield.
        """
        ...

    @webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        """Register a shield.

        :param shield_id: The identifier of the shield to register.
        :param provider_shield_id: The identifier of the shield in the provider.
        :param provider_id: The identifier of the provider.
        :param params: The parameters of the shield.
        :returns: A Shield.
        """
        ...

    @webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
    async def unregister_shield(self, identifier: str) -> None:
        """Unregister a shield.

        :param identifier: The identifier of the shield to unregister.
        """
        ...
|