llama-stack-api 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. llama_stack_api/__init__.py +1100 -0
  2. llama_stack_api/admin/__init__.py +45 -0
  3. llama_stack_api/admin/api.py +72 -0
  4. llama_stack_api/admin/fastapi_routes.py +117 -0
  5. llama_stack_api/admin/models.py +113 -0
  6. llama_stack_api/agents/__init__.py +38 -0
  7. llama_stack_api/agents/api.py +52 -0
  8. llama_stack_api/agents/fastapi_routes.py +268 -0
  9. llama_stack_api/agents/models.py +181 -0
  10. llama_stack_api/batches/__init__.py +40 -0
  11. llama_stack_api/batches/api.py +53 -0
  12. llama_stack_api/batches/fastapi_routes.py +113 -0
  13. llama_stack_api/batches/models.py +78 -0
  14. llama_stack_api/benchmarks/__init__.py +43 -0
  15. llama_stack_api/benchmarks/api.py +39 -0
  16. llama_stack_api/benchmarks/fastapi_routes.py +109 -0
  17. llama_stack_api/benchmarks/models.py +109 -0
  18. llama_stack_api/common/__init__.py +5 -0
  19. llama_stack_api/common/content_types.py +101 -0
  20. llama_stack_api/common/errors.py +110 -0
  21. llama_stack_api/common/job_types.py +38 -0
  22. llama_stack_api/common/responses.py +77 -0
  23. llama_stack_api/common/training_types.py +47 -0
  24. llama_stack_api/common/type_system.py +146 -0
  25. llama_stack_api/connectors/__init__.py +38 -0
  26. llama_stack_api/connectors/api.py +50 -0
  27. llama_stack_api/connectors/fastapi_routes.py +103 -0
  28. llama_stack_api/connectors/models.py +103 -0
  29. llama_stack_api/conversations/__init__.py +61 -0
  30. llama_stack_api/conversations/api.py +44 -0
  31. llama_stack_api/conversations/fastapi_routes.py +177 -0
  32. llama_stack_api/conversations/models.py +245 -0
  33. llama_stack_api/datasetio/__init__.py +34 -0
  34. llama_stack_api/datasetio/api.py +42 -0
  35. llama_stack_api/datasetio/fastapi_routes.py +94 -0
  36. llama_stack_api/datasetio/models.py +48 -0
  37. llama_stack_api/datasets/__init__.py +61 -0
  38. llama_stack_api/datasets/api.py +35 -0
  39. llama_stack_api/datasets/fastapi_routes.py +104 -0
  40. llama_stack_api/datasets/models.py +152 -0
  41. llama_stack_api/datatypes.py +373 -0
  42. llama_stack_api/eval/__init__.py +55 -0
  43. llama_stack_api/eval/api.py +51 -0
  44. llama_stack_api/eval/compat.py +300 -0
  45. llama_stack_api/eval/fastapi_routes.py +126 -0
  46. llama_stack_api/eval/models.py +141 -0
  47. llama_stack_api/file_processors/__init__.py +27 -0
  48. llama_stack_api/file_processors/api.py +64 -0
  49. llama_stack_api/file_processors/fastapi_routes.py +78 -0
  50. llama_stack_api/file_processors/models.py +42 -0
  51. llama_stack_api/files/__init__.py +35 -0
  52. llama_stack_api/files/api.py +51 -0
  53. llama_stack_api/files/fastapi_routes.py +124 -0
  54. llama_stack_api/files/models.py +107 -0
  55. llama_stack_api/inference/__init__.py +207 -0
  56. llama_stack_api/inference/api.py +93 -0
  57. llama_stack_api/inference/fastapi_routes.py +243 -0
  58. llama_stack_api/inference/models.py +1035 -0
  59. llama_stack_api/inspect_api/__init__.py +37 -0
  60. llama_stack_api/inspect_api/api.py +25 -0
  61. llama_stack_api/inspect_api/fastapi_routes.py +76 -0
  62. llama_stack_api/inspect_api/models.py +28 -0
  63. llama_stack_api/internal/__init__.py +9 -0
  64. llama_stack_api/internal/kvstore.py +28 -0
  65. llama_stack_api/internal/sqlstore.py +81 -0
  66. llama_stack_api/models/__init__.py +47 -0
  67. llama_stack_api/models/api.py +38 -0
  68. llama_stack_api/models/fastapi_routes.py +104 -0
  69. llama_stack_api/models/models.py +157 -0
  70. llama_stack_api/openai_responses.py +1494 -0
  71. llama_stack_api/post_training/__init__.py +73 -0
  72. llama_stack_api/post_training/api.py +36 -0
  73. llama_stack_api/post_training/fastapi_routes.py +116 -0
  74. llama_stack_api/post_training/models.py +339 -0
  75. llama_stack_api/prompts/__init__.py +47 -0
  76. llama_stack_api/prompts/api.py +44 -0
  77. llama_stack_api/prompts/fastapi_routes.py +163 -0
  78. llama_stack_api/prompts/models.py +177 -0
  79. llama_stack_api/providers/__init__.py +33 -0
  80. llama_stack_api/providers/api.py +16 -0
  81. llama_stack_api/providers/fastapi_routes.py +57 -0
  82. llama_stack_api/providers/models.py +24 -0
  83. llama_stack_api/rag_tool.py +168 -0
  84. llama_stack_api/resource.py +36 -0
  85. llama_stack_api/router_utils.py +160 -0
  86. llama_stack_api/safety/__init__.py +37 -0
  87. llama_stack_api/safety/api.py +29 -0
  88. llama_stack_api/safety/datatypes.py +83 -0
  89. llama_stack_api/safety/fastapi_routes.py +55 -0
  90. llama_stack_api/safety/models.py +38 -0
  91. llama_stack_api/schema_utils.py +251 -0
  92. llama_stack_api/scoring/__init__.py +66 -0
  93. llama_stack_api/scoring/api.py +35 -0
  94. llama_stack_api/scoring/fastapi_routes.py +67 -0
  95. llama_stack_api/scoring/models.py +81 -0
  96. llama_stack_api/scoring_functions/__init__.py +50 -0
  97. llama_stack_api/scoring_functions/api.py +39 -0
  98. llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
  99. llama_stack_api/scoring_functions/models.py +214 -0
  100. llama_stack_api/shields/__init__.py +41 -0
  101. llama_stack_api/shields/api.py +39 -0
  102. llama_stack_api/shields/fastapi_routes.py +104 -0
  103. llama_stack_api/shields/models.py +74 -0
  104. llama_stack_api/tools.py +226 -0
  105. llama_stack_api/validators.py +46 -0
  106. llama_stack_api/vector_io/__init__.py +88 -0
  107. llama_stack_api/vector_io/api.py +234 -0
  108. llama_stack_api/vector_io/fastapi_routes.py +447 -0
  109. llama_stack_api/vector_io/models.py +663 -0
  110. llama_stack_api/vector_stores.py +53 -0
  111. llama_stack_api/version.py +9 -0
  112. {llama_stack_api-0.4.3.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
  113. llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
  114. llama_stack_api-0.5.0rc1.dist-info/top_level.txt +1 -0
  115. llama_stack_api-0.4.3.dist-info/RECORD +0 -4
  116. llama_stack_api-0.4.3.dist-info/top_level.txt +0 -1
  117. {llama_stack_api-0.4.3.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Scoring API protocol and models.
8
+
9
+ This module contains the Scoring protocol definition.
10
+ Pydantic models are defined in llama_stack_api.scoring.models.
11
+ The FastAPI router is defined in llama_stack_api.scoring.fastapi_routes.
12
+ """
13
+
14
+ # Import fastapi_routes for router factory access
15
+ # Import scoring_functions for re-export
16
+ from llama_stack_api.scoring_functions import (
17
+ AggregationFunctionType,
18
+ BasicScoringFnParams,
19
+ CommonScoringFnFields,
20
+ ListScoringFunctionsResponse,
21
+ LLMAsJudgeScoringFnParams,
22
+ RegexParserScoringFnParams,
23
+ ScoringFn,
24
+ ScoringFnInput,
25
+ ScoringFnParams,
26
+ ScoringFnParamsType,
27
+ ScoringFunctions,
28
+ )
29
+
30
+ from . import fastapi_routes
31
+
32
+ # Import protocol for FastAPI router
33
+ from .api import Scoring, ScoringFunctionStore
34
+
35
+ # Import models for re-export
36
+ from .models import (
37
+ ScoreBatchRequest,
38
+ ScoreBatchResponse,
39
+ ScoreRequest,
40
+ ScoreResponse,
41
+ ScoringResult,
42
+ ScoringResultRow,
43
+ )
44
+
45
+ __all__ = [
46
+ "Scoring",
47
+ "ScoringFunctionStore",
48
+ "ScoringResult",
49
+ "ScoringResultRow",
50
+ "ScoreBatchResponse",
51
+ "ScoreResponse",
52
+ "ScoreRequest",
53
+ "ScoreBatchRequest",
54
+ "AggregationFunctionType",
55
+ "BasicScoringFnParams",
56
+ "CommonScoringFnFields",
57
+ "LLMAsJudgeScoringFnParams",
58
+ "ListScoringFunctionsResponse",
59
+ "RegexParserScoringFnParams",
60
+ "ScoringFn",
61
+ "ScoringFnInput",
62
+ "ScoringFnParams",
63
+ "ScoringFnParamsType",
64
+ "ScoringFunctions",
65
+ "fastapi_routes",
66
+ ]
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Scoring API protocol definition.
8
+
9
+ This module contains the Scoring protocol definition.
10
+ Pydantic models are defined in llama_stack_api.scoring.models.
11
+ The FastAPI router is defined in llama_stack_api.scoring.fastapi_routes.
12
+ """
13
+
14
+ from typing import Protocol, runtime_checkable
15
+
16
+ from llama_stack_api.scoring_functions import ScoringFn
17
+
18
+ from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
19
+
20
+
21
class ScoringFunctionStore(Protocol):
    """Structural interface for an object that can look up scoring functions by ID."""

    def get_scoring_function(
        self,
        scoring_fn_id: str,
    ) -> ScoringFn:
        """Return the ``ScoringFn`` for the given ``scoring_fn_id``."""
        ...
25
+
26
+
27
@runtime_checkable
class Scoring(Protocol):
    """Runtime-checkable protocol describing the scoring operations."""

    # Store the implementation exposes for looking up scoring functions by ID.
    scoring_function_store: ScoringFunctionStore

    async def score_batch(
        self,
        request: ScoreBatchRequest,
    ) -> ScoreBatchResponse:
        """Handle a ``ScoreBatchRequest`` and return a ``ScoreBatchResponse``."""
        ...

    async def score(
        self,
        request: ScoreRequest,
    ) -> ScoreResponse:
        """Handle a ``ScoreRequest`` and return a ``ScoreResponse``."""
        ...
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """FastAPI router for the Scoring API.
8
+
9
+ This module defines the FastAPI router for the Scoring API using standard
10
+ FastAPI route decorators.
11
+ """
12
+
13
+ from typing import Annotated
14
+
15
+ from fastapi import APIRouter, Body
16
+
17
+ from llama_stack_api.router_utils import standard_responses
18
+ from llama_stack_api.version import LLAMA_STACK_API_V1
19
+
20
+ from .api import Scoring
21
+ from .models import ScoreBatchRequest, ScoreBatchResponse, ScoreRequest, ScoreResponse
22
+
23
+
24
def create_router(impl: Scoring) -> APIRouter:
    """Build the FastAPI router that exposes the Scoring API.

    Two POST endpoints are registered under the ``/{LLAMA_STACK_API_V1}`` prefix.
    Each endpoint takes the request model from the request body and forwards it
    to the method of the same name on ``impl``.

    Args:
        impl: The Scoring implementation instance that serves the requests.

    Returns:
        An APIRouter carrying the ``/scoring/score`` and ``/scoring/score-batch`` routes.
    """
    api_prefix = f"/{LLAMA_STACK_API_V1}"
    router = APIRouter(prefix=api_prefix, tags=["Scoring"], responses=standard_responses)

    # POST <prefix>/scoring/score -> impl.score
    @router.post(
        "/scoring/score",
        response_model=ScoreResponse,
        summary="Score a list of rows.",
        description="Score a list of rows.",
        responses={
            200: {"description": "A ScoreResponse object containing rows and aggregated results."},
        },
    )
    async def score(request: Annotated[ScoreRequest, Body(...)]) -> ScoreResponse:
        score_response = await impl.score(request)
        return score_response

    # POST <prefix>/scoring/score-batch -> impl.score_batch
    @router.post(
        "/scoring/score-batch",
        response_model=ScoreBatchResponse,
        summary="Score a batch of rows.",
        description="Score a batch of rows.",
        responses={
            200: {"description": "A ScoreBatchResponse."},
        },
    )
    async def score_batch(request: Annotated[ScoreBatchRequest, Body(...)]) -> ScoreBatchResponse:
        batch_response = await impl.score_batch(request)
        return batch_response

    return router
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Pydantic models for Scoring API requests and responses.
8
+
9
+ This module defines the request and response models for the Scoring API
10
+ using Pydantic with Field descriptions for OpenAPI schema generation.
11
+ """
12
+
13
+ from typing import Any
14
+
15
+ from pydantic import BaseModel, Field
16
+
17
+ from llama_stack_api.schema_utils import json_schema_type
18
+ from llama_stack_api.scoring_functions import ScoringFnParams
19
+
20
# Type alias for one scored row: a mapping of metric name to its value.
ScoringResultRow = dict[str, Any]
22
+
23
+
24
@json_schema_type
class ScoringResult(BaseModel):
    """
    The result produced by one scoring function: per-row scores plus aggregated results.
    """

    # NOTE: the previous docstring said "for a single row", but this model holds a
    # list of rows and the aggregates computed over them.

    # One entry per scored row.
    score_rows: list[ScoringResultRow] = Field(
        ..., description="The scoring result for each row. Each row is a map of column name to value."
    )
    # Aggregates keyed by metric name.
    aggregated_results: dict[str, Any] = Field(..., description="Map of metric name to aggregated value")
34
+
35
+
36
@json_schema_type
class ScoreBatchResponse(BaseModel):
    """Response from batch scoring operations on datasets."""

    # Optional; left as None when no dataset identifier is provided.
    dataset_id: str | None = Field(
        description="(Optional) The identifier of the dataset that was scored",
        default=None,
    )
    # Required; one ScoringResult per scoring function name.
    results: dict[str, ScoringResult] = Field(
        description="A map of scoring function name to ScoringResult",
    )
42
+
43
+
44
@json_schema_type
class ScoreResponse(BaseModel):
    """
    The response from scoring.
    """

    # Required; one ScoringResult per scoring function name.
    results: dict[str, ScoringResult] = Field(
        description="A map of scoring function name to ScoringResult.",
    )
51
+
52
+
53
@json_schema_type
class ScoreRequest(BaseModel):
    """Request model for scoring a list of rows."""

    # Required; the rows are passed inline as plain dictionaries.
    input_rows: list[dict[str, Any]] = Field(description="The rows to score.")
    # Required; each key maps to parameters for that scoring function, or None for no parameters.
    scoring_functions: dict[str, ScoringFnParams | None] = Field(
        description="The scoring functions to use for the scoring.",
    )
61
+
62
+
63
@json_schema_type
class ScoreBatchRequest(BaseModel):
    """Request model for scoring a batch of rows from a dataset."""

    # Required; the rows come from the referenced dataset rather than the request body.
    dataset_id: str = Field(description="The ID of the dataset to score.")
    # Required; each key maps to parameters for that scoring function, or None for no parameters.
    scoring_functions: dict[str, ScoringFnParams | None] = Field(
        description="The scoring functions to use for the scoring.",
    )
    # Optional flag; off by default.
    save_results_dataset: bool = Field(
        description="Whether to save the results to a dataset.",
        default=False,
    )
72
+
73
+
74
+ __all__ = [
75
+ "ScoringResult",
76
+ "ScoringResultRow",
77
+ "ScoreBatchResponse",
78
+ "ScoreResponse",
79
+ "ScoreRequest",
80
+ "ScoreBatchRequest",
81
+ ]
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """ScoringFunctions API protocol and models.
8
+
9
+ This module contains the ScoringFunctions protocol definition.
10
+ Pydantic models are defined in llama_stack_api.scoring_functions.models.
11
+ The FastAPI router is defined in llama_stack_api.scoring_functions.fastapi_routes.
12
+ """
13
+
14
+ from . import fastapi_routes
15
+ from .api import ScoringFunctions
16
+ from .models import (
17
+ AggregationFunctionType,
18
+ BasicScoringFnParams,
19
+ CommonScoringFnFields,
20
+ GetScoringFunctionRequest,
21
+ ListScoringFunctionsRequest,
22
+ ListScoringFunctionsResponse,
23
+ LLMAsJudgeScoringFnParams,
24
+ RegexParserScoringFnParams,
25
+ RegisterScoringFunctionRequest,
26
+ ScoringFn,
27
+ ScoringFnInput,
28
+ ScoringFnParams,
29
+ ScoringFnParamsType,
30
+ UnregisterScoringFunctionRequest,
31
+ )
32
+
33
+ __all__ = [
34
+ "ScoringFunctions",
35
+ "ScoringFn",
36
+ "ScoringFnInput",
37
+ "ScoringFnParams",
38
+ "ScoringFnParamsType",
39
+ "AggregationFunctionType",
40
+ "LLMAsJudgeScoringFnParams",
41
+ "RegexParserScoringFnParams",
42
+ "BasicScoringFnParams",
43
+ "CommonScoringFnFields",
44
+ "ListScoringFunctionsResponse",
45
+ "ListScoringFunctionsRequest",
46
+ "GetScoringFunctionRequest",
47
+ "RegisterScoringFunctionRequest",
48
+ "UnregisterScoringFunctionRequest",
49
+ "fastapi_routes",
50
+ ]
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ from typing import Protocol, runtime_checkable
8
+
9
+ from .models import (
10
+ GetScoringFunctionRequest,
11
+ ListScoringFunctionsRequest,
12
+ ListScoringFunctionsResponse,
13
+ RegisterScoringFunctionRequest,
14
+ ScoringFn,
15
+ UnregisterScoringFunctionRequest,
16
+ )
17
+
18
+
19
@runtime_checkable
class ScoringFunctions(Protocol):
    """Runtime-checkable protocol for listing, getting, registering and unregistering scoring functions."""

    async def list_scoring_functions(
        self,
        request: ListScoringFunctionsRequest,
    ) -> ListScoringFunctionsResponse:
        """Handle a ``ListScoringFunctionsRequest`` and return a ``ListScoringFunctionsResponse``."""
        ...

    async def get_scoring_function(
        self,
        request: GetScoringFunctionRequest,
    ) -> ScoringFn:
        """Handle a ``GetScoringFunctionRequest`` and return a ``ScoringFn``."""
        ...

    async def register_scoring_function(
        self,
        request: RegisterScoringFunctionRequest,
    ) -> None:
        """Handle a ``RegisterScoringFunctionRequest``; nothing is returned."""
        ...

    async def unregister_scoring_function(
        self,
        request: UnregisterScoringFunctionRequest,
    ) -> None:
        """Handle an ``UnregisterScoringFunctionRequest``; nothing is returned."""
        ...
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """FastAPI router for the ScoringFunctions API.
8
+
9
+ This module defines the FastAPI router for the ScoringFunctions API using standard
10
+ FastAPI route decorators.
11
+
12
+ The router is defined in the API package to keep all API-related code together.
13
+ """
14
+
15
+ from typing import Annotated
16
+
17
+ from fastapi import APIRouter, Body, Depends
18
+
19
+ from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses
20
+ from llama_stack_api.version import LLAMA_STACK_API_V1
21
+
22
+ from .api import ScoringFunctions
23
+ from .models import (
24
+ GetScoringFunctionRequest,
25
+ ListScoringFunctionsRequest,
26
+ ListScoringFunctionsResponse,
27
+ RegisterScoringFunctionRequest,
28
+ ScoringFn,
29
+ UnregisterScoringFunctionRequest,
30
+ )
31
+
32
# Module-level dependency callables used with ``Depends(...)`` by the routes in create_router.
# NOTE(review): judging by the helper names, the list request is presumably built from query
# parameters and the get/unregister requests from path parameters — confirm in router_utils.
get_list_scoring_functions_request = create_query_dependency(ListScoringFunctionsRequest)
get_get_scoring_function_request = create_path_dependency(GetScoringFunctionRequest)
get_unregister_scoring_function_request = create_path_dependency(UnregisterScoringFunctionRequest)
35
+
36
+
37
def create_router(impl: ScoringFunctions) -> APIRouter:
    """Build the FastAPI router that exposes the ScoringFunctions API.

    Four endpoints are registered under the ``/{LLAMA_STACK_API_V1}`` prefix: list,
    get, register and unregister. Each one forwards its request model to the method
    of the same name on ``impl``. The register and unregister routes are flagged
    as deprecated.

    Args:
        impl: The ScoringFunctions implementation instance that serves the requests.

    Returns:
        An APIRouter carrying the ``/scoring-functions`` routes.
    """
    api_prefix = f"/{LLAMA_STACK_API_V1}"
    router = APIRouter(prefix=api_prefix, tags=["Scoring Functions"], responses=standard_responses)

    # GET <prefix>/scoring-functions -> impl.list_scoring_functions
    @router.get(
        "/scoring-functions",
        response_model=ListScoringFunctionsResponse,
        summary="List all scoring functions.",
        description="List all scoring functions.",
        responses={
            200: {"description": "A ListScoringFunctionsResponse."},
        },
    )
    async def list_scoring_functions(
        request: Annotated[ListScoringFunctionsRequest, Depends(get_list_scoring_functions_request)],
    ) -> ListScoringFunctionsResponse:
        listing = await impl.list_scoring_functions(request)
        return listing

    # GET <prefix>/scoring-functions/{scoring_fn_id} -> impl.get_scoring_function
    # The ``:path`` converter lets the ID contain slashes.
    @router.get(
        "/scoring-functions/{scoring_fn_id:path}",
        response_model=ScoringFn,
        summary="Get a scoring function by its ID.",
        description="Get a scoring function by its ID.",
        responses={
            200: {"description": "A ScoringFn."},
        },
    )
    async def get_scoring_function(
        request: Annotated[GetScoringFunctionRequest, Depends(get_get_scoring_function_request)],
    ) -> ScoringFn:
        scoring_fn = await impl.get_scoring_function(request)
        return scoring_fn

    # POST <prefix>/scoring-functions -> impl.register_scoring_function (deprecated route)
    @router.post(
        "/scoring-functions",
        summary="Register a scoring function.",
        description="Register a scoring function.",
        responses={
            200: {"description": "The scoring function was successfully registered."},
        },
        deprecated=True,
    )
    async def register_scoring_function(
        request: Annotated[RegisterScoringFunctionRequest, Body(...)],
    ) -> None:
        return await impl.register_scoring_function(request)

    # DELETE <prefix>/scoring-functions/{scoring_fn_id} -> impl.unregister_scoring_function (deprecated route)
    @router.delete(
        "/scoring-functions/{scoring_fn_id:path}",
        summary="Unregister a scoring function.",
        description="Unregister a scoring function.",
        responses={
            200: {"description": "The scoring function was successfully unregistered."},
        },
        deprecated=True,
    )
    async def unregister_scoring_function(
        request: Annotated[UnregisterScoringFunctionRequest, Depends(get_unregister_scoring_function_request)],
    ) -> None:
        return await impl.unregister_scoring_function(request)

    return router
@@ -0,0 +1,214 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Pydantic models for ScoringFunctions API requests and responses.
8
+
9
+ This module defines the request and response models for the ScoringFunctions API
10
+ using Pydantic with Field descriptions for OpenAPI schema generation.
11
+ """
12
+
13
+ from enum import StrEnum
14
+ from typing import Annotated, Any, Literal
15
+
16
+ from pydantic import BaseModel, Field
17
+
18
+ from llama_stack_api.common.type_system import ParamType
19
+ from llama_stack_api.resource import Resource, ResourceType
20
+ from llama_stack_api.schema_utils import json_schema_type, register_schema
21
+
22
+
23
@json_schema_type
class ScoringFnParamsType(StrEnum):
    """Types of scoring function parameter configurations.
    :cvar llm_as_judge: Use an LLM model to evaluate and score responses
    :cvar regex_parser: Use regex patterns to extract and score specific parts of responses
    :cvar basic: Basic scoring with simple aggregation functions
    """

    # Each member's value equals its name; the parameter models use these members
    # as the Literal type of their ``type`` field.
    llm_as_judge = "llm_as_judge"
    regex_parser = "regex_parser"
    basic = "basic"
34
+
35
+
36
@json_schema_type
class AggregationFunctionType(StrEnum):
    """Types of aggregation functions for scoring results.
    :cvar average: Calculate the arithmetic mean of scores
    :cvar weighted_average: Calculate a weighted average of scores
    :cvar median: Calculate the median value of scores
    :cvar categorical_count: Count occurrences of categorical values
    :cvar accuracy: Calculate accuracy as the proportion of correct answers
    """

    # Each member's value equals its name.
    average = "average"
    weighted_average = "weighted_average"
    median = "median"
    categorical_count = "categorical_count"
    accuracy = "accuracy"
51
+
52
+
53
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
    """Parameters for LLM-as-judge scoring function configuration.
    :param type: The type of scoring function parameters, always llm_as_judge
    :param judge_model: Identifier of the LLM model to use as a judge for scoring
    :param prompt_template: (Optional) Custom prompt template for the judge model
    :param judge_score_regexes: Regexes to extract the answer from generated response
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    # Fixed tag value; used as the discriminator of the ScoringFnParams union.
    type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
    judge_model: str
    prompt_template: str | None = None
    # ``default_factory=list`` replaces the previous ``lambda: []``: same fresh empty
    # list per instance, and consistent with BasicScoringFnParams.
    judge_score_regexes: list[str] = Field(
        description="Regexes to extract the answer from generated response",
        default_factory=list,
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=list,
    )
74
+
75
+
76
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
    """Parameters for regex parser scoring function configuration.
    :param type: The type of scoring function parameters, always regex_parser
    :param parsing_regexes: Regex to extract the answer from generated response
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    # Fixed tag value; used as the discriminator of the ScoringFnParams union.
    type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
    # ``default_factory=list`` replaces the previous ``lambda: []``: same fresh empty
    # list per instance, and consistent with BasicScoringFnParams.
    parsing_regexes: list[str] = Field(
        description="Regex to extract the answer from generated response",
        default_factory=list,
    )
    aggregation_functions: list[AggregationFunctionType] = Field(
        description="Aggregation functions to apply to the scores of each row",
        default_factory=list,
    )
93
+
94
+
95
@json_schema_type
class BasicScoringFnParams(BaseModel):
    """Parameters for basic scoring function configuration.
    :param type: The type of scoring function parameters, always basic
    :param aggregation_functions: Aggregation functions to apply to the scores of each row
    """

    # Fixed tag value; used as the discriminator of the ScoringFnParams union.
    type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
    # Defaults to a fresh empty list per instance.
    aggregation_functions: list[AggregationFunctionType] = Field(
        default_factory=list,
        description="Aggregation functions to apply to the scores of each row",
    )
107
+
108
+
109
# Union of the three parameter models, discriminated on the ``type`` field
# that each model pins to a distinct ScoringFnParamsType literal.
ScoringFnParams = Annotated[
    LLMAsJudgeScoringFnParams | RegexParserScoringFnParams | BasicScoringFnParams,
    Field(discriminator="type"),
]
# Register the union with the schema utilities under an explicit name.
register_schema(ScoringFnParams, name="ScoringFnParams")
114
+
115
+
116
@json_schema_type
class ListScoringFunctionsRequest(BaseModel):
    """Request model for listing scoring functions."""

    # Intentionally declares no fields; the docstring alone is a valid class body.
121
+
122
+
123
@json_schema_type
class GetScoringFunctionRequest(BaseModel):
    """Request model for getting a scoring function."""

    # Required.
    scoring_fn_id: str = Field(
        description="The ID of the scoring function to get.",
    )
128
+
129
+
130
@json_schema_type
class RegisterScoringFunctionRequest(BaseModel):
    """Request model for registering a scoring function."""

    # --- required fields ---
    scoring_fn_id: str = Field(description="The ID of the scoring function to register.")
    description: str = Field(description="The description of the scoring function.")
    return_type: ParamType = Field(description="The return type of the scoring function.")

    # --- optional fields, all defaulting to None ---
    provider_scoring_fn_id: str | None = Field(
        description="The ID of the provider scoring function to use for the scoring function.",
        default=None,
    )
    provider_id: str | None = Field(
        description="The ID of the provider to use for the scoring function.",
        default=None,
    )
    params: ScoringFnParams | None = Field(
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval.",
        default=None,
    )
145
+
146
+
147
@json_schema_type
class UnregisterScoringFunctionRequest(BaseModel):
    """Request model for unregistering a scoring function."""

    # Required.
    scoring_fn_id: str = Field(
        description="The ID of the scoring function to unregister.",
    )
152
+
153
+
154
class CommonScoringFnFields(BaseModel):
    # Base model holding the fields shared by ScoringFn and ScoringFnInput.

    # Optional free-text description.
    description: str | None = None
    # Defaults to a fresh empty dict per instance.
    metadata: dict[str, Any] = Field(
        description="Any additional metadata for this definition",
        default_factory=dict,
    )
    # Required: no default is supplied.
    return_type: ParamType = Field(description="The return type of the deterministic function")
    # Optional parameters; None when not provided.
    params: ScoringFnParams | None = Field(
        default=None,
        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
    )
167
+
168
+
169
@json_schema_type
class ScoringFn(CommonScoringFnFields, Resource):
    """A scoring function resource for evaluating model outputs.
    :param type: The resource type, always scoring_function
    """

    type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function

    # The two read-only properties below expose the resource's own attributes
    # under scoring-function-specific names.

    @property
    def provider_scoring_fn_id(self) -> str | None:
        """Alias for ``self.provider_resource_id``."""
        return self.provider_resource_id

    @property
    def scoring_fn_id(self) -> str:
        """Alias for ``self.identifier``."""
        return self.identifier
184
+
185
+
186
class ScoringFnInput(CommonScoringFnFields, BaseModel):
    # Input shape for a scoring function: the shared CommonScoringFnFields plus
    # its ID and optional provider identifiers.

    # Required.
    scoring_fn_id: str
    # Optional provider identifiers; both default to None.
    provider_id: str | None = None
    provider_scoring_fn_id: str | None = None
190
+
191
+
192
@json_schema_type
class ListScoringFunctionsResponse(BaseModel):
    """Response containing a list of scoring function objects."""

    # Required.
    data: list[ScoringFn] = Field(
        description="List of scoring function objects.",
    )
197
+
198
+
199
+ __all__ = [
200
+ "ScoringFnParamsType",
201
+ "AggregationFunctionType",
202
+ "LLMAsJudgeScoringFnParams",
203
+ "RegexParserScoringFnParams",
204
+ "BasicScoringFnParams",
205
+ "ScoringFnParams",
206
+ "ListScoringFunctionsRequest",
207
+ "GetScoringFunctionRequest",
208
+ "RegisterScoringFunctionRequest",
209
+ "UnregisterScoringFunctionRequest",
210
+ "CommonScoringFnFields",
211
+ "ScoringFn",
212
+ "ScoringFnInput",
213
+ "ListScoringFunctionsResponse",
214
+ ]