llama-stack-api 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. llama_stack_api/__init__.py +175 -20
  2. llama_stack_api/agents/__init__.py +38 -0
  3. llama_stack_api/agents/api.py +52 -0
  4. llama_stack_api/agents/fastapi_routes.py +268 -0
  5. llama_stack_api/agents/models.py +181 -0
  6. llama_stack_api/common/errors.py +15 -0
  7. llama_stack_api/connectors/__init__.py +38 -0
  8. llama_stack_api/connectors/api.py +50 -0
  9. llama_stack_api/connectors/fastapi_routes.py +103 -0
  10. llama_stack_api/connectors/models.py +103 -0
  11. llama_stack_api/conversations/__init__.py +61 -0
  12. llama_stack_api/conversations/api.py +44 -0
  13. llama_stack_api/conversations/fastapi_routes.py +177 -0
  14. llama_stack_api/conversations/models.py +245 -0
  15. llama_stack_api/datasetio/__init__.py +34 -0
  16. llama_stack_api/datasetio/api.py +42 -0
  17. llama_stack_api/datasetio/fastapi_routes.py +94 -0
  18. llama_stack_api/datasetio/models.py +48 -0
  19. llama_stack_api/eval/__init__.py +55 -0
  20. llama_stack_api/eval/api.py +51 -0
  21. llama_stack_api/eval/compat.py +300 -0
  22. llama_stack_api/eval/fastapi_routes.py +126 -0
  23. llama_stack_api/eval/models.py +141 -0
  24. llama_stack_api/inference/__init__.py +207 -0
  25. llama_stack_api/inference/api.py +93 -0
  26. llama_stack_api/inference/fastapi_routes.py +243 -0
  27. llama_stack_api/inference/models.py +1035 -0
  28. llama_stack_api/models/__init__.py +47 -0
  29. llama_stack_api/models/api.py +38 -0
  30. llama_stack_api/models/fastapi_routes.py +104 -0
  31. llama_stack_api/{models.py → models/models.py} +65 -79
  32. llama_stack_api/openai_responses.py +32 -6
  33. llama_stack_api/post_training/__init__.py +73 -0
  34. llama_stack_api/post_training/api.py +36 -0
  35. llama_stack_api/post_training/fastapi_routes.py +116 -0
  36. llama_stack_api/{post_training.py → post_training/models.py} +55 -86
  37. llama_stack_api/prompts/__init__.py +47 -0
  38. llama_stack_api/prompts/api.py +44 -0
  39. llama_stack_api/prompts/fastapi_routes.py +163 -0
  40. llama_stack_api/prompts/models.py +177 -0
  41. llama_stack_api/resource.py +0 -1
  42. llama_stack_api/safety/__init__.py +37 -0
  43. llama_stack_api/safety/api.py +29 -0
  44. llama_stack_api/safety/datatypes.py +83 -0
  45. llama_stack_api/safety/fastapi_routes.py +55 -0
  46. llama_stack_api/safety/models.py +38 -0
  47. llama_stack_api/schema_utils.py +47 -4
  48. llama_stack_api/scoring/__init__.py +66 -0
  49. llama_stack_api/scoring/api.py +35 -0
  50. llama_stack_api/scoring/fastapi_routes.py +67 -0
  51. llama_stack_api/scoring/models.py +81 -0
  52. llama_stack_api/scoring_functions/__init__.py +50 -0
  53. llama_stack_api/scoring_functions/api.py +39 -0
  54. llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
  55. llama_stack_api/{scoring_functions.py → scoring_functions/models.py} +67 -64
  56. llama_stack_api/shields/__init__.py +41 -0
  57. llama_stack_api/shields/api.py +39 -0
  58. llama_stack_api/shields/fastapi_routes.py +104 -0
  59. llama_stack_api/shields/models.py +74 -0
  60. llama_stack_api/validators.py +46 -0
  61. llama_stack_api/vector_io/__init__.py +88 -0
  62. llama_stack_api/vector_io/api.py +234 -0
  63. llama_stack_api/vector_io/fastapi_routes.py +447 -0
  64. llama_stack_api/{vector_io.py → vector_io/models.py} +99 -377
  65. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
  66. llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
  67. llama_stack_api/agents.py +0 -173
  68. llama_stack_api/connectors.py +0 -146
  69. llama_stack_api/conversations.py +0 -270
  70. llama_stack_api/datasetio.py +0 -55
  71. llama_stack_api/eval.py +0 -137
  72. llama_stack_api/inference.py +0 -1169
  73. llama_stack_api/prompts.py +0 -203
  74. llama_stack_api/safety.py +0 -132
  75. llama_stack_api/scoring.py +0 -93
  76. llama_stack_api/shields.py +0 -93
  77. llama_stack_api-0.4.4.dist-info/RECORD +0 -70
  78. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
  79. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Pydantic models for Scoring API requests and responses.
8
+
9
+ This module defines the request and response models for the Scoring API
10
+ using Pydantic with Field descriptions for OpenAPI schema generation.
11
+ """
12
+
13
+ from typing import Any
14
+
15
+ from pydantic import BaseModel, Field
16
+
17
+ from llama_stack_api.schema_utils import json_schema_type
18
+ from llama_stack_api.scoring_functions import ScoringFnParams
19
+
20
# One scored row: a mapping of metric to value.
# NOTE(review): the ``score_rows`` Field description below calls each row a "map of column name
# to value"; one of the two wordings is presumably stale — confirm against the scoring providers.
ScoringResultRow = dict[str, Any]


@json_schema_type
class ScoringResult(BaseModel):
    """
    Scoring results from a single scoring function: one score row per input row plus aggregated metrics.
    """

    # One entry per scored input row, in the order the rows were scored.
    score_rows: list[ScoringResultRow] = Field(
        ..., description="The scoring result for each row. Each row is a map of column name to value."
    )
    # Metric name -> value aggregated across all rows in ``score_rows``.
    aggregated_results: dict[str, Any] = Field(..., description="Map of metric name to aggregated value")
34
+
35
+
36
@json_schema_type
class ScoreBatchResponse(BaseModel):
    """Response from batch scoring operations on datasets."""

    # The docstring and Field descriptions are used for OpenAPI schema generation,
    # so treat them as API-facing text.
    dataset_id: str | None = Field(default=None, description="(Optional) The identifier of the dataset that was scored")
    # Keyed by scoring function name; each value carries that function's per-row and aggregated results.
    results: dict[str, ScoringResult] = Field(..., description="A map of scoring function name to ScoringResult")
42
+
43
+
44
@json_schema_type
class ScoreResponse(BaseModel):
    """
    The response from scoring.
    """

    # Same mapping shape as ScoreBatchResponse.results, but without a dataset_id.
    results: dict[str, ScoringResult] = Field(..., description="A map of scoring function name to ScoringResult.")
51
+
52
+
53
@json_schema_type
class ScoreRequest(BaseModel):
    """Request model for scoring a list of rows."""

    # Rows are free-form dicts; this model does not validate their keys.
    input_rows: list[dict[str, Any]] = Field(..., description="The rows to score.")
    # Scoring function identifier -> optional params. A None value supplies no explicit params
    # (presumably the function's registered params are used then — confirm in the provider).
    scoring_functions: dict[str, ScoringFnParams | None] = Field(
        ..., description="The scoring functions to use for the scoring."
    )
61
+
62
+
63
@json_schema_type
class ScoreBatchRequest(BaseModel):
    """Request model for scoring a batch of rows from a dataset."""

    # Rows are read from this dataset rather than passed inline (contrast ScoreRequest.input_rows).
    dataset_id: str = Field(..., description="The ID of the dataset to score.")
    # Scoring function identifier -> optional params; None supplies no explicit params.
    scoring_functions: dict[str, ScoringFnParams | None] = Field(
        ..., description="The scoring functions to use for the scoring."
    )
    # Opt-in; defaults to False.
    save_results_dataset: bool = Field(default=False, description="Whether to save the results to a dataset.")
72
+
73
+
74
# Public names of the scoring models module (controls ``from ... import *``).
__all__ = [
    "ScoringResult",
    "ScoringResultRow",
    "ScoreBatchResponse",
    "ScoreResponse",
    "ScoreRequest",
    "ScoreBatchRequest",
]
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """ScoringFunctions API protocol and models.
8
+
9
+ This module re-exports the ScoringFunctions protocol, which is defined in llama_stack_api.scoring_functions.api.
10
+ Pydantic models are defined in llama_stack_api.scoring_functions.models.
11
+ The FastAPI router is defined in llama_stack_api.scoring_functions.fastapi_routes.
12
+ """
13
+
14
+ from . import fastapi_routes
15
+ from .api import ScoringFunctions
16
+ from .models import (
17
+ AggregationFunctionType,
18
+ BasicScoringFnParams,
19
+ CommonScoringFnFields,
20
+ GetScoringFunctionRequest,
21
+ ListScoringFunctionsRequest,
22
+ ListScoringFunctionsResponse,
23
+ LLMAsJudgeScoringFnParams,
24
+ RegexParserScoringFnParams,
25
+ RegisterScoringFunctionRequest,
26
+ ScoringFn,
27
+ ScoringFnInput,
28
+ ScoringFnParams,
29
+ ScoringFnParamsType,
30
+ UnregisterScoringFunctionRequest,
31
+ )
32
+
33
# Public API of the scoring_functions package: the protocol, its Pydantic models,
# and the ``fastapi_routes`` module that provides ``create_router``.
__all__ = [
    "ScoringFunctions",
    "ScoringFn",
    "ScoringFnInput",
    "ScoringFnParams",
    "ScoringFnParamsType",
    "AggregationFunctionType",
    "LLMAsJudgeScoringFnParams",
    "RegexParserScoringFnParams",
    "BasicScoringFnParams",
    "CommonScoringFnFields",
    "ListScoringFunctionsResponse",
    "ListScoringFunctionsRequest",
    "GetScoringFunctionRequest",
    "RegisterScoringFunctionRequest",
    "UnregisterScoringFunctionRequest",
    "fastapi_routes",
]
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ from typing import Protocol, runtime_checkable
8
+
9
+ from .models import (
10
+ GetScoringFunctionRequest,
11
+ ListScoringFunctionsRequest,
12
+ ListScoringFunctionsResponse,
13
+ RegisterScoringFunctionRequest,
14
+ ScoringFn,
15
+ UnregisterScoringFunctionRequest,
16
+ )
17
+
18
+
19
@runtime_checkable
class ScoringFunctions(Protocol):
    """Structural interface for listing, fetching, registering and unregistering scoring functions.

    Every method takes a single request model and is awaited by the matching
    FastAPI handler in ``llama_stack_api.scoring_functions.fastapi_routes``.
    """

    async def list_scoring_functions(self, request: ListScoringFunctionsRequest) -> ListScoringFunctionsResponse:
        """List all scoring functions."""

    async def get_scoring_function(self, request: GetScoringFunctionRequest) -> ScoringFn:
        """Get a scoring function by its ID."""

    async def register_scoring_function(self, request: RegisterScoringFunctionRequest) -> None:
        """Register a scoring function (its HTTP route is marked deprecated)."""

    async def unregister_scoring_function(self, request: UnregisterScoringFunctionRequest) -> None:
        """Unregister a scoring function (its HTTP route is marked deprecated)."""
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """FastAPI router for the ScoringFunctions API.
8
+
9
+ This module defines the FastAPI router for the ScoringFunctions API using standard
10
+ FastAPI route decorators.
11
+
12
+ The router is defined in the API package to keep all API-related code together.
13
+ """
14
+
15
+ from typing import Annotated
16
+
17
+ from fastapi import APIRouter, Body, Depends
18
+
19
+ from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses
20
+ from llama_stack_api.version import LLAMA_STACK_API_V1
21
+
22
+ from .api import ScoringFunctions
23
+ from .models import (
24
+ GetScoringFunctionRequest,
25
+ ListScoringFunctionsRequest,
26
+ ListScoringFunctionsResponse,
27
+ RegisterScoringFunctionRequest,
28
+ ScoringFn,
29
+ UnregisterScoringFunctionRequest,
30
+ )
31
+
32
# FastAPI dependency callables that build each request model, created once at import time
# and shared by every router returned from create_router(). Per the router_utils helper
# names, the list request is built from query parameters and get/unregister from path parameters.
get_list_scoring_functions_request = create_query_dependency(ListScoringFunctionsRequest)
get_get_scoring_function_request = create_path_dependency(GetScoringFunctionRequest)
get_unregister_scoring_function_request = create_path_dependency(UnregisterScoringFunctionRequest)
35
+
36
+
37
def create_router(impl: ScoringFunctions) -> APIRouter:
    """Create a FastAPI router for the ScoringFunctions API.

    Each handler forwards its parsed request model to the matching method on
    ``impl`` and returns that method's result unchanged.

    Args:
        impl: The ScoringFunctions implementation instance

    Returns:
        APIRouter configured for the ScoringFunctions API
    """
    # All routes are mounted under the versioned API prefix and inherit
    # ``standard_responses`` in addition to their own ``responses`` entries.
    router = APIRouter(
        prefix=f"/{LLAMA_STACK_API_V1}",
        tags=["Scoring Functions"],
        responses=standard_responses,
    )

    # NOTE: handler function names become FastAPI route names and therefore feed the
    # generated OpenAPI operation IDs; rename them only deliberately.
    @router.get(
        "/scoring-functions",
        response_model=ListScoringFunctionsResponse,
        summary="List all scoring functions.",
        description="List all scoring functions.",
        responses={
            200: {"description": "A ListScoringFunctionsResponse."},
        },
    )
    async def list_scoring_functions(
        request: Annotated[ListScoringFunctionsRequest, Depends(get_list_scoring_functions_request)],
    ) -> ListScoringFunctionsResponse:
        return await impl.list_scoring_functions(request)

    # The ``:path`` converter lets scoring_fn_id contain "/" characters.
    @router.get(
        "/scoring-functions/{scoring_fn_id:path}",
        response_model=ScoringFn,
        summary="Get a scoring function by its ID.",
        description="Get a scoring function by its ID.",
        responses={
            200: {"description": "A ScoringFn."},
        },
    )
    async def get_scoring_function(
        request: Annotated[GetScoringFunctionRequest, Depends(get_get_scoring_function_request)],
    ) -> ScoringFn:
        return await impl.get_scoring_function(request)

    # Deprecated route. The single Body(...) parameter means the whole JSON body is
    # parsed as RegisterScoringFunctionRequest (not embedded under a key).
    @router.post(
        "/scoring-functions",
        summary="Register a scoring function.",
        description="Register a scoring function.",
        responses={
            200: {"description": "The scoring function was successfully registered."},
        },
        deprecated=True,
    )
    async def register_scoring_function(
        request: Annotated[RegisterScoringFunctionRequest, Body(...)],
    ) -> None:
        return await impl.register_scoring_function(request)

    # Deprecated route; scoring_fn_id comes from the path (``:path`` allows "/" in IDs).
    @router.delete(
        "/scoring-functions/{scoring_fn_id:path}",
        summary="Unregister a scoring function.",
        description="Unregister a scoring function.",
        responses={
            200: {"description": "The scoring function was successfully unregistered."},
        },
        deprecated=True,
    )
    async def unregister_scoring_function(
        request: Annotated[UnregisterScoringFunctionRequest, Depends(get_unregister_scoring_function_request)],
    ) -> None:
        return await impl.unregister_scoring_function(request)

    return router
@@ -4,26 +4,22 @@
4
4
  # This source code is licensed under the terms described in the LICENSE file in
5
5
  # the root directory of this source tree.
6
6
 
7
- # TODO: use enum.StrEnum when we drop support for python 3.10
7
+ """Pydantic models for ScoringFunctions API requests and responses.
8
+
9
+ This module defines the request and response models for the ScoringFunctions API
10
+ using Pydantic with Field descriptions for OpenAPI schema generation.
11
+ """
12
+
8
13
  from enum import StrEnum
9
- from typing import (
10
- Annotated,
11
- Any,
12
- Literal,
13
- Protocol,
14
- runtime_checkable,
15
- )
14
+ from typing import Annotated, Any, Literal
16
15
 
17
16
  from pydantic import BaseModel, Field
18
17
 
19
18
  from llama_stack_api.common.type_system import ParamType
20
19
  from llama_stack_api.resource import Resource, ResourceType
21
- from llama_stack_api.schema_utils import json_schema_type, register_schema, webmethod
22
- from llama_stack_api.version import LLAMA_STACK_API_V1
20
+ from llama_stack_api.schema_utils import json_schema_type, register_schema
23
21
 
24
22
 
25
- # Perhaps more structure can be imposed on these functions. Maybe they could be associated
26
- # with standard metrics so they can be rolled up?
27
23
  @json_schema_type
28
24
  class ScoringFnParamsType(StrEnum):
29
25
  """Types of scoring function parameter configurations.
@@ -117,6 +113,44 @@ ScoringFnParams = Annotated[
117
113
# Register the ScoringFnParams alias with the schema registry under an explicit name
# (presumably so the union appears as a named OpenAPI component — confirm in schema_utils).
register_schema(ScoringFnParams, name="ScoringFnParams")
118
114
 
119
115
 
116
@json_schema_type
class ListScoringFunctionsRequest(BaseModel):
    """Request model for listing scoring functions."""

    # Intentionally has no fields: listing currently takes no parameters. The docstring
    # alone is a valid class body, so no ``pass`` placeholder is needed (ruff PIE790).
121
+
122
+
123
@json_schema_type
class GetScoringFunctionRequest(BaseModel):
    """Request model for getting a scoring function."""

    # Filled from the ``{scoring_fn_id:path}`` route segment by the router's path dependency.
    scoring_fn_id: str = Field(..., description="The ID of the scoring function to get.")
128
+
129
+
130
@json_schema_type
class RegisterScoringFunctionRequest(BaseModel):
    """Request model for registering a scoring function."""

    # Required: identity, human description and return type of the new scoring function.
    scoring_fn_id: str = Field(..., description="The ID of the scoring function to register.")
    description: str = Field(..., description="The description of the scoring function.")
    return_type: ParamType = Field(..., description="The return type of the scoring function.")
    # Optional provider routing and default params; each defaults to None when omitted.
    provider_scoring_fn_id: str | None = Field(
        default=None, description="The ID of the provider scoring function to use for the scoring function."
    )
    provider_id: str | None = Field(default=None, description="The ID of the provider to use for the scoring function.")
    params: ScoringFnParams | None = Field(
        default=None,
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval.",
    )
145
+
146
+
147
@json_schema_type
class UnregisterScoringFunctionRequest(BaseModel):
    """Request model for unregistering a scoring function."""

    # Filled from the ``{scoring_fn_id:path}`` segment of the DELETE route.
    scoring_fn_id: str = Field(..., description="The ID of the scoring function to unregister.")
152
+
153
+
120
154
  class CommonScoringFnFields(BaseModel):
121
155
  description: str | None = None
122
156
  metadata: dict[str, Any] = Field(
@@ -157,55 +191,24 @@ class ScoringFnInput(CommonScoringFnFields, BaseModel):
157
191
 
158
192
@json_schema_type
class ListScoringFunctionsResponse(BaseModel):
    """Response containing a list of scoring function objects."""

    # Returned by GET /scoring-functions.
    data: list[ScoringFn] = Field(..., description="List of scoring function objects.")
197
+
198
+
199
# Public names of the scoring_functions models module (controls ``from ... import *``).
__all__ = [
    "ScoringFnParamsType",
    "AggregationFunctionType",
    "LLMAsJudgeScoringFnParams",
    "RegexParserScoringFnParams",
    "BasicScoringFnParams",
    "ScoringFnParams",
    "ListScoringFunctionsRequest",
    "GetScoringFunctionRequest",
    "RegisterScoringFunctionRequest",
    "UnregisterScoringFunctionRequest",
    "CommonScoringFnFields",
    "ScoringFn",
    "ScoringFnInput",
    "ListScoringFunctionsResponse",
]
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Shields API protocol and models.
8
+
9
+ This module re-exports the Shields protocol, which is defined in llama_stack_api.shields.api.
10
+ Pydantic models are defined in llama_stack_api.shields.models.
11
+ The FastAPI router is defined in llama_stack_api.shields.fastapi_routes.
12
+ """
13
+
14
+ # Import fastapi_routes for router factory access
15
+ from . import fastapi_routes
16
+
17
+ # Import protocol for re-export
18
+ from .api import Shields
19
+
20
+ # Import models for re-export
21
+ from .models import (
22
+ CommonShieldFields,
23
+ GetShieldRequest,
24
+ ListShieldsResponse,
25
+ RegisterShieldRequest,
26
+ Shield,
27
+ ShieldInput,
28
+ UnregisterShieldRequest,
29
+ )
30
+
31
# Public API of the shields package: the protocol, its Pydantic models,
# and the ``fastapi_routes`` module that provides ``create_router``.
__all__ = [
    "Shields",
    "Shield",
    "ShieldInput",
    "CommonShieldFields",
    "ListShieldsResponse",
    "GetShieldRequest",
    "RegisterShieldRequest",
    "UnregisterShieldRequest",
    "fastapi_routes",
]
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Shields API protocol definition.
8
+
9
+ This module contains the Shields protocol for managing shield resources.
10
+ """
11
+
12
+ from typing import Protocol, runtime_checkable
13
+
14
+ from .models import (
15
+ GetShieldRequest,
16
+ ListShieldsResponse,
17
+ RegisterShieldRequest,
18
+ Shield,
19
+ UnregisterShieldRequest,
20
+ )
21
+
22
+
23
@runtime_checkable
class Shields(Protocol):
    """Structural interface for managing shield resources.

    Each method is awaited by the matching FastAPI handler in
    ``llama_stack_api.shields.fastapi_routes``.
    """

    async def list_shields(
        self,
    ) -> ListShieldsResponse:
        """List all shields.

        Returns:
            A ListShieldsResponse.
        """

    async def get_shield(
        self,
        request: GetShieldRequest,
    ) -> Shield:
        """Get a shield by its identifier.

        Args:
            request: Identifies the shield to fetch.

        Returns:
            A Shield.
        """

    async def register_shield(
        self,
        request: RegisterShieldRequest,
    ) -> Shield:
        """Register a shield.

        Args:
            request: Registration payload.

        Returns:
            A Shield.
        """

    async def unregister_shield(
        self,
        request: UnregisterShieldRequest,
    ) -> None:
        """Unregister a shield.

        Args:
            request: Identifies the shield to remove.
        """
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """FastAPI router for the Shields API.
8
+
9
+ This module defines the FastAPI router for the Shields API using standard
10
+ FastAPI route decorators.
11
+ """
12
+
13
+ from typing import Annotated
14
+
15
+ from fastapi import APIRouter, Body, Depends
16
+
17
+ from llama_stack_api.router_utils import create_path_dependency, standard_responses
18
+ from llama_stack_api.version import LLAMA_STACK_API_V1
19
+
20
+ from .api import Shields
21
+ from .models import (
22
+ GetShieldRequest,
23
+ ListShieldsResponse,
24
+ RegisterShieldRequest,
25
+ Shield,
26
+ UnregisterShieldRequest,
27
+ )
28
+
29
# Automatically generate dependency functions from Pydantic models.
# Created once at import time and shared by every router from create_router(); both
# requests are built from the ``{identifier:path}`` route segment.
get_get_shield_request = create_path_dependency(GetShieldRequest)
get_unregister_shield_request = create_path_dependency(UnregisterShieldRequest)
32
+
33
+
34
def create_router(impl: Shields) -> APIRouter:
    """Create a FastAPI router for the Shields API.

    Each handler forwards to the matching method on ``impl`` and returns that
    method's result unchanged.

    Args:
        impl: The Shields implementation instance

    Returns:
        APIRouter configured for the Shields API
    """
    # All routes are mounted under the versioned API prefix and inherit
    # ``standard_responses`` in addition to their own ``responses`` entries.
    router = APIRouter(
        prefix=f"/{LLAMA_STACK_API_V1}",
        tags=["Shields"],
        responses=standard_responses,
    )

    # NOTE: handler function names become FastAPI route names and therefore feed the
    # generated OpenAPI operation IDs; rename them only deliberately.
    # Unlike the other routes, listing takes no request model (Shields.list_shields has no argument).
    @router.get(
        "/shields",
        response_model=ListShieldsResponse,
        summary="List all shields.",
        description="List all shields.",
        responses={
            200: {"description": "A ListShieldsResponse."},
        },
    )
    async def list_shields() -> ListShieldsResponse:
        return await impl.list_shields()

    # The ``:path`` converter lets the identifier contain "/" characters.
    @router.get(
        "/shields/{identifier:path}",
        response_model=Shield,
        summary="Get a shield by its identifier.",
        description="Get a shield by its identifier.",
        responses={
            200: {"description": "A Shield."},
        },
    )
    async def get_shield(
        request: Annotated[GetShieldRequest, Depends(get_get_shield_request)],
    ) -> Shield:
        return await impl.get_shield(request)

    # Deprecated route. The single Body(...) parameter means the whole JSON body is
    # parsed as RegisterShieldRequest (not embedded under a key).
    @router.post(
        "/shields",
        response_model=Shield,
        summary="Register a shield.",
        description="Register a shield.",
        responses={
            200: {"description": "A Shield."},
        },
        deprecated=True,
    )
    async def register_shield(
        request: Annotated[RegisterShieldRequest, Body(...)],
    ) -> Shield:
        return await impl.register_shield(request)

    # Deprecated route; the identifier comes from the path (``:path`` allows "/").
    @router.delete(
        "/shields/{identifier:path}",
        summary="Unregister a shield.",
        description="Unregister a shield.",
        responses={
            200: {"description": "The shield was successfully unregistered."},
        },
        deprecated=True,
    )
    async def unregister_shield(
        request: Annotated[UnregisterShieldRequest, Depends(get_unregister_shield_request)],
    ) -> None:
        return await impl.unregister_shield(request)

    return router