llama-stack-api 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. llama_stack_api/__init__.py +175 -20
  2. llama_stack_api/agents/__init__.py +38 -0
  3. llama_stack_api/agents/api.py +52 -0
  4. llama_stack_api/agents/fastapi_routes.py +268 -0
  5. llama_stack_api/agents/models.py +181 -0
  6. llama_stack_api/common/errors.py +15 -0
  7. llama_stack_api/connectors/__init__.py +38 -0
  8. llama_stack_api/connectors/api.py +50 -0
  9. llama_stack_api/connectors/fastapi_routes.py +103 -0
  10. llama_stack_api/connectors/models.py +103 -0
  11. llama_stack_api/conversations/__init__.py +61 -0
  12. llama_stack_api/conversations/api.py +44 -0
  13. llama_stack_api/conversations/fastapi_routes.py +177 -0
  14. llama_stack_api/conversations/models.py +245 -0
  15. llama_stack_api/datasetio/__init__.py +34 -0
  16. llama_stack_api/datasetio/api.py +42 -0
  17. llama_stack_api/datasetio/fastapi_routes.py +94 -0
  18. llama_stack_api/datasetio/models.py +48 -0
  19. llama_stack_api/eval/__init__.py +55 -0
  20. llama_stack_api/eval/api.py +51 -0
  21. llama_stack_api/eval/compat.py +300 -0
  22. llama_stack_api/eval/fastapi_routes.py +126 -0
  23. llama_stack_api/eval/models.py +141 -0
  24. llama_stack_api/inference/__init__.py +207 -0
  25. llama_stack_api/inference/api.py +93 -0
  26. llama_stack_api/inference/fastapi_routes.py +243 -0
  27. llama_stack_api/inference/models.py +1035 -0
  28. llama_stack_api/models/__init__.py +47 -0
  29. llama_stack_api/models/api.py +38 -0
  30. llama_stack_api/models/fastapi_routes.py +104 -0
  31. llama_stack_api/{models.py → models/models.py} +65 -79
  32. llama_stack_api/openai_responses.py +32 -6
  33. llama_stack_api/post_training/__init__.py +73 -0
  34. llama_stack_api/post_training/api.py +36 -0
  35. llama_stack_api/post_training/fastapi_routes.py +116 -0
  36. llama_stack_api/{post_training.py → post_training/models.py} +55 -86
  37. llama_stack_api/prompts/__init__.py +47 -0
  38. llama_stack_api/prompts/api.py +44 -0
  39. llama_stack_api/prompts/fastapi_routes.py +163 -0
  40. llama_stack_api/prompts/models.py +177 -0
  41. llama_stack_api/resource.py +0 -1
  42. llama_stack_api/safety/__init__.py +37 -0
  43. llama_stack_api/safety/api.py +29 -0
  44. llama_stack_api/safety/datatypes.py +83 -0
  45. llama_stack_api/safety/fastapi_routes.py +55 -0
  46. llama_stack_api/safety/models.py +38 -0
  47. llama_stack_api/schema_utils.py +47 -4
  48. llama_stack_api/scoring/__init__.py +66 -0
  49. llama_stack_api/scoring/api.py +35 -0
  50. llama_stack_api/scoring/fastapi_routes.py +67 -0
  51. llama_stack_api/scoring/models.py +81 -0
  52. llama_stack_api/scoring_functions/__init__.py +50 -0
  53. llama_stack_api/scoring_functions/api.py +39 -0
  54. llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
  55. llama_stack_api/{scoring_functions.py → scoring_functions/models.py} +67 -64
  56. llama_stack_api/shields/__init__.py +41 -0
  57. llama_stack_api/shields/api.py +39 -0
  58. llama_stack_api/shields/fastapi_routes.py +104 -0
  59. llama_stack_api/shields/models.py +74 -0
  60. llama_stack_api/validators.py +46 -0
  61. llama_stack_api/vector_io/__init__.py +88 -0
  62. llama_stack_api/vector_io/api.py +234 -0
  63. llama_stack_api/vector_io/fastapi_routes.py +447 -0
  64. llama_stack_api/{vector_io.py → vector_io/models.py} +99 -377
  65. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
  66. llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
  67. llama_stack_api/agents.py +0 -173
  68. llama_stack_api/connectors.py +0 -146
  69. llama_stack_api/conversations.py +0 -270
  70. llama_stack_api/datasetio.py +0 -55
  71. llama_stack_api/eval.py +0 -137
  72. llama_stack_api/inference.py +0 -1169
  73. llama_stack_api/prompts.py +0 -203
  74. llama_stack_api/safety.py +0 -132
  75. llama_stack_api/scoring.py +0 -93
  76. llama_stack_api/shields.py +0 -93
  77. llama_stack_api-0.4.4.dist-info/RECORD +0 -70
  78. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
  79. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Pydantic models for Prompts API requests and responses.
8
+
9
+ This module defines the request and response models for the Prompts API
10
+ using Pydantic with Field descriptions for OpenAPI schema generation.
11
+ """
12
+
13
+ import re
14
+ import secrets
15
+
16
+ from pydantic import BaseModel, Field, field_validator, model_validator
17
+
18
+ from llama_stack_api.schema_utils import json_schema_type
19
+
20
+
21
+ @json_schema_type
22
+ class Prompt(BaseModel):
23
+ """A prompt resource representing a stored OpenAI Compatible prompt template in Llama Stack."""
24
+
25
+ prompt: str | None = Field(default=None, description="The system prompt with variable placeholders")
26
+ version: int = Field(description="Version (integer starting at 1, incremented on save)", ge=1)
27
+ prompt_id: str = Field(description="Unique identifier in format 'pmpt_<48-digit-hash>'")
28
+ variables: list[str] = Field(
29
+ default_factory=list, description="List of variable names that can be used in the prompt template"
30
+ )
31
+ is_default: bool = Field(
32
+ default=False, description="Boolean indicating whether this version is the default version"
33
+ )
34
+
35
+ @field_validator("prompt_id")
36
+ @classmethod
37
+ def validate_prompt_id(cls, prompt_id: str) -> str:
38
+ if not isinstance(prompt_id, str):
39
+ raise TypeError("prompt_id must be a string in format 'pmpt_<48-digit-hash>'")
40
+
41
+ if not prompt_id.startswith("pmpt_"):
42
+ raise ValueError("prompt_id must start with 'pmpt_' prefix")
43
+
44
+ hex_part = prompt_id[5:]
45
+ if len(hex_part) != 48:
46
+ raise ValueError("prompt_id must be in format 'pmpt_<48-digit-hash>' (48 lowercase hex chars)")
47
+
48
+ for char in hex_part:
49
+ if char not in "0123456789abcdef":
50
+ raise ValueError("prompt_id hex part must contain only lowercase hex characters [0-9a-f]")
51
+
52
+ return prompt_id
53
+
54
+ @field_validator("version")
55
+ @classmethod
56
+ def validate_version(cls, prompt_version: int) -> int:
57
+ if prompt_version < 1:
58
+ raise ValueError("version must be >= 1")
59
+ return prompt_version
60
+
61
+ @model_validator(mode="after")
62
+ def validate_prompt_variables(self):
63
+ """Validate that all variables used in the prompt are declared in the variables list."""
64
+ if not self.prompt:
65
+ return self
66
+
67
+ prompt_variables = set(re.findall(r"{{\s*(\w+)\s*}}", self.prompt))
68
+ declared_variables = set(self.variables)
69
+
70
+ undeclared = prompt_variables - declared_variables
71
+ if undeclared:
72
+ raise ValueError(f"Prompt contains undeclared variables: {sorted(undeclared)}")
73
+
74
+ return self
75
+
76
+ @classmethod
77
+ def generate_prompt_id(cls) -> str:
78
+ # Generate 48 hex characters (24 bytes)
79
+ random_bytes = secrets.token_bytes(24)
80
+ hex_string = random_bytes.hex()
81
+ return f"pmpt_{hex_string}"
82
+
83
+
84
+ @json_schema_type
85
+ class ListPromptsResponse(BaseModel):
86
+ """Response model to list prompts."""
87
+
88
+ data: list[Prompt]
89
+
90
+
91
+ # Request models for each endpoint
92
+
93
+
94
+ @json_schema_type
95
+ class ListPromptVersionsRequest(BaseModel):
96
+ """Request model for listing all versions of a prompt."""
97
+
98
+ prompt_id: str = Field(..., description="The identifier of the prompt to list versions for.")
99
+
100
+
101
+ @json_schema_type
102
+ class GetPromptRequest(BaseModel):
103
+ """Request model for getting a prompt by ID and optional version."""
104
+
105
+ prompt_id: str = Field(..., description="The identifier of the prompt to get.")
106
+ version: int | None = Field(default=None, description="The version of the prompt to get (defaults to latest).")
107
+
108
+
109
+ @json_schema_type
110
+ class CreatePromptRequest(BaseModel):
111
+ """Request model for creating a new prompt."""
112
+
113
+ prompt: str = Field(..., description="The prompt text content with variable placeholders.")
114
+ variables: list[str] | None = Field(
115
+ default=None, description="List of variable names that can be used in the prompt template."
116
+ )
117
+
118
+
119
+ @json_schema_type
120
+ class UpdatePromptBodyRequest(BaseModel):
121
+ """Request body model for updating a prompt."""
122
+
123
+ prompt: str = Field(..., description="The updated prompt text content.")
124
+ version: int = Field(..., description="The current version of the prompt being updated.")
125
+ variables: list[str] | None = Field(
126
+ default=None, description="Updated list of variable names that can be used in the prompt template."
127
+ )
128
+ set_as_default: bool = Field(default=True, description="Set the new version as the default (default=True).")
129
+
130
+
131
+ @json_schema_type
132
+ class UpdatePromptRequest(BaseModel):
133
+ """Request model for updating a prompt (combines path and body parameters)."""
134
+
135
+ prompt_id: str = Field(..., description="The identifier of the prompt to update.")
136
+ prompt: str = Field(..., description="The updated prompt text content.")
137
+ version: int = Field(..., description="The current version of the prompt being updated.")
138
+ variables: list[str] | None = Field(
139
+ default=None, description="Updated list of variable names that can be used in the prompt template."
140
+ )
141
+ set_as_default: bool = Field(default=True, description="Set the new version as the default (default=True).")
142
+
143
+
144
+ @json_schema_type
145
+ class DeletePromptRequest(BaseModel):
146
+ """Request model for deleting a prompt."""
147
+
148
+ prompt_id: str = Field(..., description="The identifier of the prompt to delete.")
149
+
150
+
151
+ @json_schema_type
152
+ class SetDefaultVersionBodyRequest(BaseModel):
153
+ """Request body model for setting the default version of a prompt."""
154
+
155
+ version: int = Field(..., description="The version to set as default.")
156
+
157
+
158
+ @json_schema_type
159
+ class SetDefaultVersionRequest(BaseModel):
160
+ """Request model for setting the default version of a prompt (combines path and body parameters)."""
161
+
162
+ prompt_id: str = Field(..., description="The identifier of the prompt.")
163
+ version: int = Field(..., description="The version to set as default.")
164
+
165
+
166
+ __all__ = [
167
+ "CreatePromptRequest",
168
+ "DeletePromptRequest",
169
+ "GetPromptRequest",
170
+ "ListPromptVersionsRequest",
171
+ "ListPromptsResponse",
172
+ "Prompt",
173
+ "SetDefaultVersionBodyRequest",
174
+ "SetDefaultVersionRequest",
175
+ "UpdatePromptBodyRequest",
176
+ "UpdatePromptRequest",
177
+ ]
@@ -19,7 +19,6 @@ class ResourceType(StrEnum):
19
19
  tool = "tool"
20
20
  tool_group = "tool_group"
21
21
  prompt = "prompt"
22
- connector = "connector"
23
22
 
24
23
 
25
24
  class Resource(BaseModel):
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Safety API protocol and models.
8
+
9
+ This module contains the Safety protocol definition for content moderation and safety shields.
10
+ Pydantic models are defined in llama_stack_api.safety.models.
11
+ The FastAPI router is defined in llama_stack_api.safety.fastapi_routes.
12
+ """
13
+
14
+ from . import fastapi_routes
15
+ from .api import Safety
16
+ from .datatypes import (
17
+ ModerationObject,
18
+ ModerationObjectResults,
19
+ RunShieldResponse,
20
+ SafetyViolation,
21
+ ShieldStore,
22
+ ViolationLevel,
23
+ )
24
+ from .models import RunModerationRequest, RunShieldRequest
25
+
26
+ __all__ = [
27
+ "Safety",
28
+ "ShieldStore",
29
+ "ModerationObject",
30
+ "ModerationObjectResults",
31
+ "ViolationLevel",
32
+ "SafetyViolation",
33
+ "RunShieldResponse",
34
+ "RunShieldRequest",
35
+ "RunModerationRequest",
36
+ "fastapi_routes",
37
+ ]
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ from typing import Protocol, runtime_checkable
8
+
9
+ from llama_stack_api.safety.datatypes import ModerationObject, RunShieldResponse, ShieldStore
10
+
11
+ from .models import RunModerationRequest, RunShieldRequest
12
+
13
+
14
+ @runtime_checkable
15
+ class Safety(Protocol):
16
+ """Safety API for content moderation and safety shields.
17
+
18
+ OpenAI-compatible Moderations API with additional shield capabilities.
19
+ """
20
+
21
+ shield_store: ShieldStore
22
+
23
+ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
24
+ """Run a safety shield on messages."""
25
+ ...
26
+
27
+ async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
28
+ """Classify if inputs are potentially harmful."""
29
+ ...
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ from enum import Enum
8
+ from typing import Any, Protocol
9
+
10
+ from pydantic import BaseModel, Field
11
+
12
+ from llama_stack_api.schema_utils import json_schema_type
13
+ from llama_stack_api.shields import GetShieldRequest, Shield
14
+
15
+
16
+ @json_schema_type
17
+ class ModerationObjectResults(BaseModel):
18
+ """A moderation result object containing flagged status and category information."""
19
+
20
+ flagged: bool = Field(..., description="Whether any of the below categories are flagged")
21
+ categories: dict[str, bool] | None = Field(
22
+ None, description="A dictionary of the categories, and whether they are flagged or not"
23
+ )
24
+ category_applied_input_types: dict[str, list[str]] | None = Field(
25
+ None, description="A dictionary of the categories along with the input type(s) that the score applies to"
26
+ )
27
+ category_scores: dict[str, float] | None = Field(
28
+ None, description="A dictionary of the categories along with their scores as predicted by model"
29
+ )
30
+ user_message: str | None = Field(None, description="A message to convey to the user about the moderation result")
31
+ metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata about the moderation")
32
+
33
+
34
+ @json_schema_type
35
+ class ModerationObject(BaseModel):
36
+ """A moderation object containing the results of content classification."""
37
+
38
+ id: str = Field(..., description="The unique identifier for the moderation request")
39
+ model: str = Field(..., description="The model used to generate the moderation results")
40
+ results: list[ModerationObjectResults] = Field(..., description="A list of moderation result objects")
41
+
42
+
43
+ @json_schema_type
44
+ class ViolationLevel(Enum):
45
+ """Severity level of a safety violation."""
46
+
47
+ INFO = "info" # Informational level violation that does not require action
48
+ WARN = "warn" # Warning level violation that suggests caution but allows continuation
49
+ ERROR = "error" # Error level violation that requires blocking or intervention
50
+
51
+
52
+ @json_schema_type
53
+ class SafetyViolation(BaseModel):
54
+ """Details of a safety violation detected by content moderation."""
55
+
56
+ violation_level: ViolationLevel = Field(..., description="Severity level of the violation")
57
+ user_message: str | None = Field(None, description="Message to convey to the user about the violation")
58
+ metadata: dict[str, Any] = Field(
59
+ default_factory=dict, description="Additional metadata including specific violation codes"
60
+ )
61
+
62
+
63
+ @json_schema_type
64
+ class RunShieldResponse(BaseModel):
65
+ """Response from running a safety shield."""
66
+
67
+ violation: SafetyViolation | None = Field(None, description="Safety violation detected by the shield, if any")
68
+
69
+
70
+ class ShieldStore(Protocol):
71
+ """Protocol for accessing shields."""
72
+
73
+ async def get_shield(self, request: GetShieldRequest) -> Shield: ...
74
+
75
+
76
+ __all__ = [
77
+ "ModerationObjectResults",
78
+ "ModerationObject",
79
+ "ViolationLevel",
80
+ "SafetyViolation",
81
+ "RunShieldResponse",
82
+ "ShieldStore",
83
+ ]
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ from typing import Annotated
8
+
9
+ from fastapi import APIRouter, Body
10
+
11
+ from llama_stack_api.router_utils import standard_responses
12
+ from llama_stack_api.version import LLAMA_STACK_API_V1
13
+
14
+ from .api import Safety
15
+ from .datatypes import ModerationObject, RunShieldResponse
16
+ from .models import RunModerationRequest, RunShieldRequest
17
+
18
+
19
+ def create_router(impl: Safety) -> APIRouter:
20
+ """Create a FastAPI router for the Safety API."""
21
+ router = APIRouter(
22
+ prefix=f"/{LLAMA_STACK_API_V1}",
23
+ tags=["Safety"],
24
+ responses=standard_responses,
25
+ )
26
+
27
+ @router.post(
28
+ "/safety/run-shield",
29
+ response_model=RunShieldResponse,
30
+ summary="Run Shield",
31
+ description="Run a safety shield on messages to check for policy violations.",
32
+ responses={
33
+ 200: {"description": "The shield response indicating any violations detected."},
34
+ },
35
+ )
36
+ async def run_shield(
37
+ request: Annotated[RunShieldRequest, Body(...)],
38
+ ) -> RunShieldResponse:
39
+ return await impl.run_shield(request)
40
+
41
+ @router.post(
42
+ "/moderations",
43
+ response_model=ModerationObject,
44
+ summary="Create Moderation",
45
+ description="Classifies if text inputs are potentially harmful. OpenAI-compatible endpoint.",
46
+ responses={
47
+ 200: {"description": "The moderation results for the input."},
48
+ },
49
+ )
50
+ async def run_moderation(
51
+ request: Annotated[RunModerationRequest, Body(...)],
52
+ ) -> ModerationObject:
53
+ return await impl.run_moderation(request)
54
+
55
+ return router
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+ from llama_stack_api.inference import OpenAIMessageParam
10
+ from llama_stack_api.schema_utils import json_schema_type
11
+
12
+
13
+ @json_schema_type
14
+ class RunShieldRequest(BaseModel):
15
+ """Request model for running a safety shield."""
16
+
17
+ shield_id: str = Field(..., description="The identifier of the shield to run", min_length=1)
18
+ messages: list[OpenAIMessageParam] = Field(..., description="The messages to run the shield on")
19
+
20
+
21
+ @json_schema_type
22
+ class RunModerationRequest(BaseModel):
23
+ """Request model for running content moderation."""
24
+
25
+ input: str | list[str] = Field(
26
+ ...,
27
+ description="Input (or inputs) to classify. Can be a single string or an array of strings.",
28
+ )
29
+ model: str | None = Field(
30
+ None,
31
+ description="The content moderation model to use. If not specified, the default shield will be used.",
32
+ )
33
+
34
+
35
+ __all__ = [
36
+ "RunShieldRequest",
37
+ "RunModerationRequest",
38
+ ]
@@ -149,7 +149,6 @@ class WebMethod:
149
149
  raw_bytes_request_body: bool | None = False
150
150
  # A descriptive name of the corresponding span created by tracing
151
151
  descriptive_name: str | None = None
152
- required_scope: str | None = None
153
152
  deprecated: bool | None = False
154
153
  require_authentication: bool | None = True
155
154
 
@@ -166,7 +165,6 @@ def webmethod(
166
165
  response_examples: list[Any] | None = None,
167
166
  raw_bytes_request_body: bool | None = False,
168
167
  descriptive_name: str | None = None,
169
- required_scope: str | None = None,
170
168
  deprecated: bool | None = False,
171
169
  require_authentication: bool | None = True,
172
170
  ) -> Callable[[CallableT], CallableT]:
@@ -177,7 +175,6 @@ def webmethod(
177
175
  :param public: True if the operation can be invoked without prior authentication.
178
176
  :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
179
177
  :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
180
- :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
181
178
  :param require_authentication: Whether this endpoint requires authentication (default True).
182
179
  """
183
180
 
@@ -191,7 +188,6 @@ def webmethod(
191
188
  response_examples=response_examples,
192
189
  raw_bytes_request_body=raw_bytes_request_body,
193
190
  descriptive_name=descriptive_name,
194
- required_scope=required_scope,
195
191
  deprecated=deprecated,
196
192
  require_authentication=require_authentication if require_authentication is not None else True,
197
193
  )
@@ -206,3 +202,50 @@ def webmethod(
206
202
  return func
207
203
 
208
204
  return wrap
205
+
206
+
207
+ def remove_null_from_anyof(schema: dict, *, add_nullable: bool = False) -> None:
208
+ """Remove null type from anyOf and optionally add nullable flag.
209
+
210
+ Converts Pydantic's default OpenAPI 3.1 style:
211
+ anyOf: [{type: X, enum: [...]}, {type: null}]
212
+
213
+ To flattened format:
214
+ type: X
215
+ enum: [...]
216
+ nullable: true # only if add_nullable=True
217
+
218
+ Args:
219
+ schema: The JSON schema dict to modify in-place
220
+ add_nullable: If True, adds 'nullable: true' when null was present.
221
+ Use True for OpenAPI 3.0 compatibility with OpenAI's spec.
222
+ """
223
+ # Handle anyOf format: anyOf: [{type: string, enum: [...]}, {type: null}]
224
+ if "anyOf" in schema:
225
+ non_null = [s for s in schema["anyOf"] if s.get("type") != "null"]
226
+ has_null = len(non_null) < len(schema["anyOf"])
227
+
228
+ if len(non_null) == 1:
229
+ # Flatten to single type
230
+ only_schema = non_null[0]
231
+ schema.pop("anyOf")
232
+ schema.update(only_schema)
233
+ if has_null and add_nullable:
234
+ schema["nullable"] = True
235
+
236
+ # Handle OpenAPI 3.1 format: type: ['string', 'null']
237
+ elif isinstance(schema.get("type"), list) and "null" in schema["type"]:
238
+ has_null = "null" in schema["type"]
239
+ schema["type"].remove("null")
240
+ if len(schema["type"]) == 1:
241
+ schema["type"] = schema["type"][0]
242
+ if has_null and add_nullable:
243
+ schema["nullable"] = True
244
+
245
+
246
+ def nullable_openai_style(schema: dict) -> None:
247
+ """Shorthand for remove_null_from_anyof with add_nullable=True.
248
+
249
+ Use this for fields that need OpenAPI 3.0 nullable style to match OpenAI's spec.
250
+ """
251
+ remove_null_from_anyof(schema, add_nullable=True)
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Scoring API protocol and models.
8
+
9
+ This module contains the Scoring protocol definition.
10
+ Pydantic models are defined in llama_stack_api.scoring.models.
11
+ The FastAPI router is defined in llama_stack_api.scoring.fastapi_routes.
12
+ """
13
+
14
+ # Import fastapi_routes for router factory access
15
+ # Import scoring_functions for re-export
16
+ from llama_stack_api.scoring_functions import (
17
+ AggregationFunctionType,
18
+ BasicScoringFnParams,
19
+ CommonScoringFnFields,
20
+ ListScoringFunctionsResponse,
21
+ LLMAsJudgeScoringFnParams,
22
+ RegexParserScoringFnParams,
23
+ ScoringFn,
24
+ ScoringFnInput,
25
+ ScoringFnParams,
26
+ ScoringFnParamsType,
27
+ ScoringFunctions,
28
+ )
29
+
30
+ from . import fastapi_routes
31
+
32
+ # Import protocol for FastAPI router
33
+ from .api import Scoring, ScoringFunctionStore
34
+
35
+ # Import models for re-export
36
+ from .models import (
37
+ ScoreBatchRequest,
38
+ ScoreBatchResponse,
39
+ ScoreRequest,
40
+ ScoreResponse,
41
+ ScoringResult,
42
+ ScoringResultRow,
43
+ )
44
+
45
+ __all__ = [
46
+ "Scoring",
47
+ "ScoringFunctionStore",
48
+ "ScoringResult",
49
+ "ScoringResultRow",
50
+ "ScoreBatchResponse",
51
+ "ScoreResponse",
52
+ "ScoreRequest",
53
+ "ScoreBatchRequest",
54
+ "AggregationFunctionType",
55
+ "BasicScoringFnParams",
56
+ "CommonScoringFnFields",
57
+ "LLMAsJudgeScoringFnParams",
58
+ "ListScoringFunctionsResponse",
59
+ "RegexParserScoringFnParams",
60
+ "ScoringFn",
61
+ "ScoringFnInput",
62
+ "ScoringFnParams",
63
+ "ScoringFnParamsType",
64
+ "ScoringFunctions",
65
+ "fastapi_routes",
66
+ ]
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Scoring API protocol definition.
8
+
9
+ This module contains the Scoring protocol definition.
10
+ Pydantic models are defined in llama_stack_api.scoring.models.
11
+ The FastAPI router is defined in llama_stack_api.scoring.fastapi_routes.
12
+ """
13
+
14
+ from typing import Protocol, runtime_checkable
15
+
16
+ from llama_stack_api.scoring_functions import ScoringFn
17
+
18
+ from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
19
+
20
+
21
+ class ScoringFunctionStore(Protocol):
22
+ """Protocol for storing and retrieving scoring functions."""
23
+
24
+ def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
25
+
26
+
27
+ @runtime_checkable
28
+ class Scoring(Protocol):
29
+ """Protocol for scoring operations."""
30
+
31
+ scoring_function_store: ScoringFunctionStore
32
+
33
+ async def score_batch(self, request: ScoreBatchRequest) -> ScoreBatchResponse: ...
34
+
35
+ async def score(self, request: ScoreRequest) -> ScoreResponse: ...
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """FastAPI router for the Scoring API.
8
+
9
+ This module defines the FastAPI router for the Scoring API using standard
10
+ FastAPI route decorators.
11
+ """
12
+
13
+ from typing import Annotated
14
+
15
+ from fastapi import APIRouter, Body
16
+
17
+ from llama_stack_api.router_utils import standard_responses
18
+ from llama_stack_api.version import LLAMA_STACK_API_V1
19
+
20
+ from .api import Scoring
21
+ from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
22
+
23
+
24
+ def create_router(impl: Scoring) -> APIRouter:
25
+ """Create a FastAPI router for the Scoring API.
26
+
27
+ Args:
28
+ impl: The Scoring implementation instance
29
+
30
+ Returns:
31
+ APIRouter configured for the Scoring API
32
+ """
33
+ router = APIRouter(
34
+ prefix=f"/{LLAMA_STACK_API_V1}",
35
+ tags=["Scoring"],
36
+ responses=standard_responses,
37
+ )
38
+
39
+ @router.post(
40
+ "/scoring/score",
41
+ response_model=ScoreResponse,
42
+ summary="Score a list of rows.",
43
+ description="Score a list of rows.",
44
+ responses={
45
+ 200: {"description": "A ScoreResponse object containing rows and aggregated results."},
46
+ },
47
+ )
48
+ async def score(
49
+ request: Annotated[ScoreRequest, Body(...)],
50
+ ) -> ScoreResponse:
51
+ return await impl.score(request)
52
+
53
+ @router.post(
54
+ "/scoring/score-batch",
55
+ response_model=ScoreBatchResponse,
56
+ summary="Score a batch of rows.",
57
+ description="Score a batch of rows.",
58
+ responses={
59
+ 200: {"description": "A ScoreBatchResponse."},
60
+ },
61
+ )
62
+ async def score_batch(
63
+ request: Annotated[ScoreBatchRequest, Body(...)],
64
+ ) -> ScoreBatchResponse:
65
+ return await impl.score_batch(request)
66
+
67
+ return router