llama-stack-api 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as published in their respective public registries.
- llama_stack_api/__init__.py +175 -20
- llama_stack_api/agents/__init__.py +38 -0
- llama_stack_api/agents/api.py +52 -0
- llama_stack_api/agents/fastapi_routes.py +268 -0
- llama_stack_api/agents/models.py +181 -0
- llama_stack_api/common/errors.py +15 -0
- llama_stack_api/connectors/__init__.py +38 -0
- llama_stack_api/connectors/api.py +50 -0
- llama_stack_api/connectors/fastapi_routes.py +103 -0
- llama_stack_api/connectors/models.py +103 -0
- llama_stack_api/conversations/__init__.py +61 -0
- llama_stack_api/conversations/api.py +44 -0
- llama_stack_api/conversations/fastapi_routes.py +177 -0
- llama_stack_api/conversations/models.py +245 -0
- llama_stack_api/datasetio/__init__.py +34 -0
- llama_stack_api/datasetio/api.py +42 -0
- llama_stack_api/datasetio/fastapi_routes.py +94 -0
- llama_stack_api/datasetio/models.py +48 -0
- llama_stack_api/eval/__init__.py +55 -0
- llama_stack_api/eval/api.py +51 -0
- llama_stack_api/eval/compat.py +300 -0
- llama_stack_api/eval/fastapi_routes.py +126 -0
- llama_stack_api/eval/models.py +141 -0
- llama_stack_api/inference/__init__.py +207 -0
- llama_stack_api/inference/api.py +93 -0
- llama_stack_api/inference/fastapi_routes.py +243 -0
- llama_stack_api/inference/models.py +1035 -0
- llama_stack_api/models/__init__.py +47 -0
- llama_stack_api/models/api.py +38 -0
- llama_stack_api/models/fastapi_routes.py +104 -0
- llama_stack_api/{models.py → models/models.py} +65 -79
- llama_stack_api/openai_responses.py +32 -6
- llama_stack_api/post_training/__init__.py +73 -0
- llama_stack_api/post_training/api.py +36 -0
- llama_stack_api/post_training/fastapi_routes.py +116 -0
- llama_stack_api/{post_training.py → post_training/models.py} +55 -86
- llama_stack_api/prompts/__init__.py +47 -0
- llama_stack_api/prompts/api.py +44 -0
- llama_stack_api/prompts/fastapi_routes.py +163 -0
- llama_stack_api/prompts/models.py +177 -0
- llama_stack_api/resource.py +0 -1
- llama_stack_api/safety/__init__.py +37 -0
- llama_stack_api/safety/api.py +29 -0
- llama_stack_api/safety/datatypes.py +83 -0
- llama_stack_api/safety/fastapi_routes.py +55 -0
- llama_stack_api/safety/models.py +38 -0
- llama_stack_api/schema_utils.py +47 -4
- llama_stack_api/scoring/__init__.py +66 -0
- llama_stack_api/scoring/api.py +35 -0
- llama_stack_api/scoring/fastapi_routes.py +67 -0
- llama_stack_api/scoring/models.py +81 -0
- llama_stack_api/scoring_functions/__init__.py +50 -0
- llama_stack_api/scoring_functions/api.py +39 -0
- llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
- llama_stack_api/{scoring_functions.py → scoring_functions/models.py} +67 -64
- llama_stack_api/shields/__init__.py +41 -0
- llama_stack_api/shields/api.py +39 -0
- llama_stack_api/shields/fastapi_routes.py +104 -0
- llama_stack_api/shields/models.py +74 -0
- llama_stack_api/validators.py +46 -0
- llama_stack_api/vector_io/__init__.py +88 -0
- llama_stack_api/vector_io/api.py +234 -0
- llama_stack_api/vector_io/fastapi_routes.py +447 -0
- llama_stack_api/{vector_io.py → vector_io/models.py} +99 -377
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
- llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
- llama_stack_api/agents.py +0 -173
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/eval.py +0 -137
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/safety.py +0 -132
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/shields.py +0 -93
- llama_stack_api-0.4.4.dist-info/RECORD +0 -70
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/top_level.txt +0 -0
```diff
--- /dev/null
+++ b/llama_stack_api/eval/fastapi_routes.py
@@ -0,0 +1,126 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Annotated
+
+from fastapi import APIRouter, Body, Depends
+
+from llama_stack_api.common.job_types import Job
+from llama_stack_api.router_utils import create_path_dependency, standard_responses
+from llama_stack_api.version import LLAMA_STACK_API_V1ALPHA
+
+from .api import Eval
+from .models import (
+    BenchmarkIdRequest,
+    EvaluateResponse,
+    EvaluateRowsBodyRequest,
+    EvaluateRowsRequest,
+    JobCancelRequest,
+    JobResultRequest,
+    JobStatusRequest,
+    RunEvalBodyRequest,
+    RunEvalRequest,
+)
+
+get_benchmark_id_request = create_path_dependency(BenchmarkIdRequest)
+
+
+def create_router(impl: Eval) -> APIRouter:
+    """Create a FastAPI router for the Eval API."""
+    router = APIRouter(
+        prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
+        tags=["Eval"],
+        responses=standard_responses,
+    )
+
+    @router.post(
+        "/eval/benchmarks/{benchmark_id}/jobs",
+        response_model=Job,
+        summary="Run Eval",
+        description="Run an evaluation on a benchmark.",
+        responses={
+            200: {"description": "The job that was created to run the evaluation."},
+        },
+    )
+    async def run_eval(
+        benchmark_id_request: Annotated[BenchmarkIdRequest, Depends(get_benchmark_id_request)],
+        body_request: Annotated[RunEvalBodyRequest, Body(...)],
+    ) -> Job:
+        request = RunEvalRequest(
+            benchmark_id=benchmark_id_request.benchmark_id,
+            benchmark_config=body_request.benchmark_config,
+        )
+        return await impl.run_eval(request)
+
+    @router.post(
+        "/eval/benchmarks/{benchmark_id}/evaluations",
+        response_model=EvaluateResponse,
+        summary="Evaluate Rows",
+        description="Evaluate a list of rows on a benchmark.",
+        responses={
+            200: {"description": "EvaluateResponse object containing generations and scores."},
+        },
+    )
+    async def evaluate_rows(
+        benchmark_id_request: Annotated[BenchmarkIdRequest, Depends(get_benchmark_id_request)],
+        body_request: Annotated[EvaluateRowsBodyRequest, Body(...)],
+    ) -> EvaluateResponse:
+        request = EvaluateRowsRequest(
+            benchmark_id=benchmark_id_request.benchmark_id,
+            input_rows=body_request.input_rows,
+            scoring_functions=body_request.scoring_functions,
+            benchmark_config=body_request.benchmark_config,
+        )
+        return await impl.evaluate_rows(request)
+
+    @router.get(
+        "/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
+        response_model=Job,
+        summary="Job Status",
+        description="Get the status of a job.",
+        responses={
+            200: {"description": "The status of the evaluation job."},
+        },
+    )
+    async def job_status(
+        benchmark_id: str,
+        job_id: str,
+    ) -> Job:
+        request = JobStatusRequest(benchmark_id=benchmark_id, job_id=job_id)
+        return await impl.job_status(request)
+
+    @router.delete(
+        "/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
+        summary="Job Cancel",
+        description="Cancel a job.",
+        responses={
+            200: {"description": "Successful Response"},
+        },
+    )
+    async def job_cancel(
+        benchmark_id: str,
+        job_id: str,
+    ) -> None:
+        request = JobCancelRequest(benchmark_id=benchmark_id, job_id=job_id)
+        return await impl.job_cancel(request)
+
+    @router.get(
+        "/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
+        response_model=EvaluateResponse,
+        summary="Job Result",
+        description="Get the result of a job.",
+        responses={
+            200: {"description": "The result of the job."},
+        },
+    )
+    async def job_result(
+        benchmark_id: str,
+        job_id: str,
+    ) -> EvaluateResponse:
+        request = JobResultRequest(benchmark_id=benchmark_id, job_id=job_id)
+        return await impl.job_result(request)
+
+    return router
```
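For orientation, here is a minimal wiring sketch (not part of this diff) showing how the `create_router` factory above might be mounted on a FastAPI application. `StubEval` and its elided method bodies are hypothetical placeholders for a real `Eval` implementation; only the methods the router actually dispatches to are shown.

```python
# Hypothetical sketch: mount the Eval router on a FastAPI app.
from fastapi import FastAPI

from llama_stack_api.common.job_types import Job
from llama_stack_api.eval.fastapi_routes import create_router
from llama_stack_api.eval.models import (
    EvaluateResponse,
    EvaluateRowsRequest,
    JobCancelRequest,
    JobResultRequest,
    JobStatusRequest,
    RunEvalRequest,
)


class StubEval:
    """Illustrative stand-in implementing the methods the router calls."""

    async def run_eval(self, request: RunEvalRequest) -> Job: ...

    async def evaluate_rows(self, request: EvaluateRowsRequest) -> EvaluateResponse: ...

    async def job_status(self, request: JobStatusRequest) -> Job: ...

    async def job_cancel(self, request: JobCancelRequest) -> None: ...

    async def job_result(self, request: JobResultRequest) -> EvaluateResponse: ...


app = FastAPI()
# Routes land under the LLAMA_STACK_API_V1ALPHA prefix, e.g.
# POST /{v1alpha}/eval/benchmarks/{benchmark_id}/jobs.
app.include_router(create_router(StubEval()))
```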
```diff
--- /dev/null
+++ b/llama_stack_api/eval/models.py
@@ -0,0 +1,141 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Literal
+
+from pydantic import BaseModel, Field
+
+from llama_stack_api.inference import SamplingParams, SystemMessage
+from llama_stack_api.schema_utils import json_schema_type
+from llama_stack_api.scoring import ScoringResult
+from llama_stack_api.scoring_functions import ScoringFnParams
+
+
+@json_schema_type
+class ModelCandidate(BaseModel):
+    """A model candidate for evaluation."""
+
+    type: Literal["model"] = "model"
+    model: str = Field(..., description="The model ID to evaluate", min_length=1)
+    sampling_params: SamplingParams = Field(..., description="The sampling parameters for the model")
+    system_message: SystemMessage | None = Field(
+        None, description="The system message providing instructions or context to the model"
+    )
+
+
+EvalCandidate = ModelCandidate
+
+
+@json_schema_type
+class BenchmarkConfig(BaseModel):
+    """A benchmark configuration for evaluation."""
+
+    eval_candidate: EvalCandidate = Field(..., description="The candidate to evaluate")
+    scoring_params: dict[str, ScoringFnParams] = Field(
+        default_factory=dict,
+        description="Map between scoring function id and parameters for each scoring function you want to run",
+    )
+    num_examples: int | None = Field(
+        None,
+        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
+        ge=1,
+    )
+    # we could optionally add any specific dataset config here
+
+
+@json_schema_type
+class EvaluateResponse(BaseModel):
+    """The response from an evaluation."""
+
+    generations: list[dict[str, Any]] = Field(..., description="The generations from the evaluation")
+    scores: dict[str, ScoringResult] = Field(
+        ..., description="The scores from the evaluation. Each key in the dict is a scoring function name"
+    )
+
+
+@json_schema_type
+class BenchmarkIdRequest(BaseModel):
+    """Request model containing benchmark_id path parameter."""
+
+    benchmark_id: str = Field(..., description="The ID of the benchmark", min_length=1)
+
+
+@json_schema_type
+class RunEvalRequest(BaseModel):
+    """Request model for running an evaluation on a benchmark."""
+
+    benchmark_id: str = Field(..., description="The ID of the benchmark to run the evaluation on", min_length=1)
+    benchmark_config: BenchmarkConfig = Field(..., description="The configuration for the benchmark")
+
+
+@json_schema_type
+class RunEvalBodyRequest(BaseModel):
+    """Request body model for running an evaluation (without path parameter)."""
+
+    benchmark_config: BenchmarkConfig = Field(..., description="The configuration for the benchmark")
+
+
+@json_schema_type
+class EvaluateRowsRequest(BaseModel):
+    """Request model for evaluating a list of rows on a benchmark."""
+
+    benchmark_id: str = Field(..., description="The ID of the benchmark to run the evaluation on", min_length=1)
+    input_rows: list[dict[str, Any]] = Field(..., description="The rows to evaluate", min_length=1)
+    scoring_functions: list[str] = Field(
+        ..., description="The scoring functions to use for the evaluation", min_length=1
+    )
+    benchmark_config: BenchmarkConfig = Field(..., description="The configuration for the benchmark")
+
+
+@json_schema_type
+class EvaluateRowsBodyRequest(BaseModel):
+    """Request body model for evaluating rows (without path parameter)."""
+
+    input_rows: list[dict[str, Any]] = Field(..., description="The rows to evaluate", min_length=1)
+    scoring_functions: list[str] = Field(
+        ..., description="The scoring functions to use for the evaluation", min_length=1
+    )
+    benchmark_config: BenchmarkConfig = Field(..., description="The configuration for the benchmark")
+
+
+@json_schema_type
+class JobStatusRequest(BaseModel):
+    """Request model for getting the status of a job."""
+
+    benchmark_id: str = Field(..., description="The ID of the benchmark associated with the job", min_length=1)
+    job_id: str = Field(..., description="The ID of the job to get the status of", min_length=1)
+
+
+@json_schema_type
+class JobCancelRequest(BaseModel):
+    """Request model for canceling a job."""
+
+    benchmark_id: str = Field(..., description="The ID of the benchmark associated with the job", min_length=1)
+    job_id: str = Field(..., description="The ID of the job to cancel", min_length=1)
+
+
+@json_schema_type
+class JobResultRequest(BaseModel):
+    """Request model for getting the result of a job."""
+
+    benchmark_id: str = Field(..., description="The ID of the benchmark associated with the job", min_length=1)
+    job_id: str = Field(..., description="The ID of the job to get the result of", min_length=1)
+
+
+__all__ = [
+    "ModelCandidate",
+    "EvalCandidate",
+    "BenchmarkConfig",
+    "EvaluateResponse",
+    "BenchmarkIdRequest",
+    "RunEvalRequest",
+    "RunEvalBodyRequest",
+    "EvaluateRowsRequest",
+    "EvaluateRowsBodyRequest",
+    "JobStatusRequest",
+    "JobCancelRequest",
+    "JobResultRequest",
+]
```
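A usage sketch for the request models above, building the body of `POST /eval/benchmarks/{benchmark_id}/jobs`. Assumptions are labeled in the comments: `SamplingParams` is assumed to be constructible with defaults, and the model identifier is an arbitrary placeholder.

```python
# Illustrative only: construct a RunEvalBodyRequest from the models above.
from llama_stack_api.eval.models import (
    BenchmarkConfig,
    ModelCandidate,
    RunEvalBodyRequest,
)
from llama_stack_api.inference import SamplingParams

body = RunEvalBodyRequest(
    benchmark_config=BenchmarkConfig(
        eval_candidate=ModelCandidate(
            model="example/model-id",  # placeholder model ID
            sampling_params=SamplingParams(),  # assumes default construction is valid
        ),
        num_examples=10,  # evaluate a small sample while testing (must be >= 1)
    )
)
print(body.model_dump_json(indent=2))  # scoring_params defaults to {}
```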
```diff
--- /dev/null
+++ b/llama_stack_api/inference/__init__.py
@@ -0,0 +1,207 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Inference API protocol and models.
+
+This module contains the Inference protocol definition.
+Pydantic models are defined in llama_stack_api.inference.models.
+The FastAPI router is defined in llama_stack_api.inference.fastapi_routes.
+"""
+
+# Import common types for backward compatibility
+# (these were previously available from the old inference.py)
+from llama_stack_api.common.content_types import InterleavedContent
+
+# Import fastapi_routes for router factory access
+from . import fastapi_routes
+
+# Import protocol for re-export
+from .api import Inference, InferenceProvider, ModelStore
+
+# Import models for re-export
+from .models import (
+    AllowedToolsConfig,
+    Bf16QuantizationConfig,
+    ChatCompletionResponseEventType,
+    CompletionRequest,
+    CustomToolConfig,
+    EmbeddingsResponse,
+    EmbeddingTaskType,
+    Fp8QuantizationConfig,
+    FunctionToolConfig,
+    GetChatCompletionRequest,
+    GrammarResponseFormat,
+    GreedySamplingStrategy,
+    Int4QuantizationConfig,
+    JsonSchemaResponseFormat,
+    ListChatCompletionsRequest,
+    ListOpenAIChatCompletionResponse,
+    LogProbConfig,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartParam,
+    OpenAIChatCompletionContentPartTextParam,
+    OpenAIChatCompletionMessageContent,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionTextOnlyMessageContent,
+    OpenAIChatCompletionToolCall,
+    OpenAIChatCompletionToolCallFunction,
+    OpenAIChatCompletionToolChoice,
+    OpenAIChatCompletionToolChoiceAllowedTools,
+    OpenAIChatCompletionToolChoiceCustomTool,
+    OpenAIChatCompletionToolChoiceFunctionTool,
+    OpenAIChatCompletionUsage,
+    OpenAIChatCompletionUsageCompletionTokensDetails,
+    OpenAIChatCompletionUsagePromptTokensDetails,
+    OpenAIChoice,
+    OpenAIChoiceDelta,
+    OpenAIChoiceLogprobs,
+    OpenAIChunkChoice,
+    OpenAICompletion,
+    OpenAICompletionChoice,
+    OpenAICompletionLogprobs,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAICompletionWithInputMessages,
+    OpenAIDeveloperMessageParam,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
+    OpenAIFile,
+    OpenAIFileFile,
+    OpenAIFinishReason,
+    OpenAIImageURL,
+    OpenAIJSONSchema,
+    OpenAIMessageParam,
+    OpenAIResponseFormatJSONObject,
+    OpenAIResponseFormatJSONSchema,
+    OpenAIResponseFormatParam,
+    OpenAIResponseFormatText,
+    OpenAISystemMessageParam,
+    OpenAITokenLogProb,
+    OpenAIToolMessageParam,
+    OpenAITopLogProb,
+    OpenAIUserMessageParam,
+    QuantizationConfig,
+    QuantizationType,
+    RerankData,
+    RerankRequest,
+    RerankResponse,
+    ResponseFormat,
+    ResponseFormatType,
+    SamplingParams,
+    SamplingStrategy,
+    SystemMessage,
+    SystemMessageBehavior,
+    TextTruncation,
+    TokenLogProbs,
+    ToolChoice,
+    ToolResponseMessage,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+    UserMessage,
+)
+
+__all__ = [
+    # Protocol
+    "Inference",
+    "InferenceProvider",
+    "ModelStore",
+    # Common types (for backward compatibility)
+    "InterleavedContent",
+    # Sampling
+    "GreedySamplingStrategy",
+    "TopPSamplingStrategy",
+    "TopKSamplingStrategy",
+    "SamplingStrategy",
+    "SamplingParams",
+    "LogProbConfig",
+    # Quantization
+    "QuantizationType",
+    "Fp8QuantizationConfig",
+    "Bf16QuantizationConfig",
+    "Int4QuantizationConfig",
+    "QuantizationConfig",
+    # Messages
+    "UserMessage",
+    "SystemMessage",
+    "ToolResponseMessage",
+    "ToolChoice",
+    "TokenLogProbs",
+    # Response
+    "ChatCompletionResponseEventType",
+    "ResponseFormatType",
+    "JsonSchemaResponseFormat",
+    "GrammarResponseFormat",
+    "ResponseFormat",
+    "CompletionRequest",
+    "SystemMessageBehavior",
+    "EmbeddingsResponse",
+    "RerankData",
+    "RerankResponse",
+    # OpenAI Compatibility
+    "OpenAIChatCompletionContentPartTextParam",
+    "OpenAIImageURL",
+    "OpenAIChatCompletionContentPartImageParam",
+    "OpenAIFileFile",
+    "OpenAIFile",
+    "OpenAIChatCompletionContentPartParam",
+    "OpenAIChatCompletionMessageContent",
+    "OpenAIChatCompletionTextOnlyMessageContent",
+    "OpenAIUserMessageParam",
+    "OpenAISystemMessageParam",
+    "OpenAIChatCompletionToolCallFunction",
+    "OpenAIChatCompletionToolCall",
+    "OpenAIAssistantMessageParam",
+    "OpenAIToolMessageParam",
+    "OpenAIDeveloperMessageParam",
+    "OpenAIMessageParam",
+    "OpenAIResponseFormatText",
+    "OpenAIJSONSchema",
+    "OpenAIResponseFormatJSONSchema",
+    "OpenAIResponseFormatJSONObject",
+    "OpenAIResponseFormatParam",
+    "FunctionToolConfig",
+    "OpenAIChatCompletionToolChoiceFunctionTool",
+    "CustomToolConfig",
+    "OpenAIChatCompletionToolChoiceCustomTool",
+    "AllowedToolsConfig",
+    "OpenAIChatCompletionToolChoiceAllowedTools",
+    "OpenAIChatCompletionToolChoice",
+    "OpenAITopLogProb",
+    "OpenAITokenLogProb",
+    "OpenAIChoiceLogprobs",
+    "OpenAIChoiceDelta",
+    "OpenAIChunkChoice",
+    "OpenAIChoice",
+    "OpenAIChatCompletionUsageCompletionTokensDetails",
+    "OpenAIChatCompletionUsagePromptTokensDetails",
+    "OpenAIChatCompletionUsage",
+    "OpenAIChatCompletion",
+    "OpenAIChatCompletionChunk",
+    "OpenAICompletionLogprobs",
+    "OpenAICompletionChoice",
+    "OpenAICompletion",
+    "OpenAIFinishReason",
+    "OpenAIEmbeddingData",
+    "OpenAIEmbeddingUsage",
+    "OpenAIEmbeddingsResponse",
+    "TextTruncation",
+    "EmbeddingTaskType",
+    "OpenAICompletionWithInputMessages",
+    "ListOpenAIChatCompletionResponse",
+    "OpenAICompletionRequestWithExtraBody",
+    "OpenAIChatCompletionRequestWithExtraBody",
+    "OpenAIEmbeddingsRequestWithExtraBody",
+    # Request Models
+    "ListChatCompletionsRequest",
+    "GetChatCompletionRequest",
+    "RerankRequest",
+    # Router factory module
+    "fastapi_routes",
+]
```
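Because `__init__.py` re-exports the protocol and every model, imports that previously targeted the flat `inference.py` module should keep resolving. A quick sanity sketch:

```python
# The package __init__ mirrors the old flat module, so both import paths
# below should refer to the same class object.
from llama_stack_api.inference import OpenAIChatCompletion
from llama_stack_api.inference.models import OpenAIChatCompletion as FromModels

assert OpenAIChatCompletion is FromModels
```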
```diff
--- /dev/null
+++ b/llama_stack_api/inference/api.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from collections.abc import AsyncIterator
+from typing import Protocol, runtime_checkable
+
+from llama_stack_api.models import Model
+
+from .models import (
+    GetChatCompletionRequest,
+    ListChatCompletionsRequest,
+    ListOpenAIChatCompletionResponse,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAICompletionWithInputMessages,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+    RerankRequest,
+    RerankResponse,
+)
+
+
+class ModelStore(Protocol):
+    async def get_model(self, identifier: str) -> Model: ...
+
+
+@runtime_checkable
+class InferenceProvider(Protocol):
+    """
+    This protocol defines the interface that should be implemented by all inference providers.
+    """
+
+    API_NAMESPACE: str = "Inference"
+
+    model_store: ModelStore | None = None
+
+    async def rerank(
+        self,
+        request: RerankRequest,
+    ) -> RerankResponse:
+        """Rerank a list of documents based on their relevance to a query."""
+        raise NotImplementedError("Reranking is not implemented")
+        return  # this is so mypy's safe-super rule will consider the method concrete
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
+        """Generate an OpenAI-compatible completion for the given prompt using the specified model."""
+        ...
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """Generate an OpenAI-compatible chat completion for the given messages using the specified model."""
+        ...
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """Generate OpenAI-compatible embeddings for the given input using the specified model."""
+        ...
+
+
+class Inference(InferenceProvider):
+    """Inference
+
+    Llama Stack Inference API for generating completions, chat completions, and embeddings.
+
+    This API provides the raw interface to the underlying models. Three kinds of models are supported:
+    - LLM models: these models generate "raw" and "chat" (conversational) completions.
+    - Embedding models: these models generate embeddings to be used for semantic search.
+    - Rerank models: these models reorder the documents based on their relevance to a query.
+    """
+
+    async def list_chat_completions(
+        self,
+        request: ListChatCompletionsRequest,
+    ) -> ListOpenAIChatCompletionResponse:
+        """List stored chat completions."""
+        raise NotImplementedError("List chat completions is not implemented")
+
+    async def get_chat_completion(self, request: GetChatCompletionRequest) -> OpenAICompletionWithInputMessages:
+        """Retrieve a stored chat completion by its ID."""
+        raise NotImplementedError("Get chat completion is not implemented")
```
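Since `InferenceProvider` is decorated with `@runtime_checkable`, `isinstance()` can verify that an object structurally conforms to the protocol (member presence only, not signatures). A hypothetical conformance sketch; `EchoProvider` is illustrative, not a real provider:

```python
from llama_stack_api.inference.api import InferenceProvider


class EchoProvider:
    # Data members declared by the protocol must also be present.
    API_NAMESPACE = "Inference"
    model_store = None

    async def rerank(self, request): ...

    async def openai_completion(self, params): ...

    async def openai_chat_completion(self, params): ...

    async def openai_embeddings(self, params): ...


# Structural check: passes because every protocol member exists.
assert isinstance(EchoProvider(), InferenceProvider)
```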