llama-stack-api 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- llama_stack_api/__init__.py +175 -20
- llama_stack_api/agents/__init__.py +38 -0
- llama_stack_api/agents/api.py +52 -0
- llama_stack_api/agents/fastapi_routes.py +268 -0
- llama_stack_api/agents/models.py +181 -0
- llama_stack_api/common/errors.py +15 -0
- llama_stack_api/connectors/__init__.py +38 -0
- llama_stack_api/connectors/api.py +50 -0
- llama_stack_api/connectors/fastapi_routes.py +103 -0
- llama_stack_api/connectors/models.py +103 -0
- llama_stack_api/conversations/__init__.py +61 -0
- llama_stack_api/conversations/api.py +44 -0
- llama_stack_api/conversations/fastapi_routes.py +177 -0
- llama_stack_api/conversations/models.py +245 -0
- llama_stack_api/datasetio/__init__.py +34 -0
- llama_stack_api/datasetio/api.py +42 -0
- llama_stack_api/datasetio/fastapi_routes.py +94 -0
- llama_stack_api/datasetio/models.py +48 -0
- llama_stack_api/eval/__init__.py +55 -0
- llama_stack_api/eval/api.py +51 -0
- llama_stack_api/eval/compat.py +300 -0
- llama_stack_api/eval/fastapi_routes.py +126 -0
- llama_stack_api/eval/models.py +141 -0
- llama_stack_api/inference/__init__.py +207 -0
- llama_stack_api/inference/api.py +93 -0
- llama_stack_api/inference/fastapi_routes.py +243 -0
- llama_stack_api/inference/models.py +1035 -0
- llama_stack_api/models/__init__.py +47 -0
- llama_stack_api/models/api.py +38 -0
- llama_stack_api/models/fastapi_routes.py +104 -0
- llama_stack_api/{models.py → models/models.py} +65 -79
- llama_stack_api/openai_responses.py +32 -6
- llama_stack_api/post_training/__init__.py +73 -0
- llama_stack_api/post_training/api.py +36 -0
- llama_stack_api/post_training/fastapi_routes.py +116 -0
- llama_stack_api/{post_training.py → post_training/models.py} +55 -86
- llama_stack_api/prompts/__init__.py +47 -0
- llama_stack_api/prompts/api.py +44 -0
- llama_stack_api/prompts/fastapi_routes.py +163 -0
- llama_stack_api/prompts/models.py +177 -0
- llama_stack_api/resource.py +0 -1
- llama_stack_api/safety/__init__.py +37 -0
- llama_stack_api/safety/api.py +29 -0
- llama_stack_api/safety/datatypes.py +83 -0
- llama_stack_api/safety/fastapi_routes.py +55 -0
- llama_stack_api/safety/models.py +38 -0
- llama_stack_api/schema_utils.py +47 -4
- llama_stack_api/scoring/__init__.py +66 -0
- llama_stack_api/scoring/api.py +35 -0
- llama_stack_api/scoring/fastapi_routes.py +67 -0
- llama_stack_api/scoring/models.py +81 -0
- llama_stack_api/scoring_functions/__init__.py +50 -0
- llama_stack_api/scoring_functions/api.py +39 -0
- llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
- llama_stack_api/{scoring_functions.py → scoring_functions/models.py} +67 -64
- llama_stack_api/shields/__init__.py +41 -0
- llama_stack_api/shields/api.py +39 -0
- llama_stack_api/shields/fastapi_routes.py +104 -0
- llama_stack_api/shields/models.py +74 -0
- llama_stack_api/validators.py +46 -0
- llama_stack_api/vector_io/__init__.py +88 -0
- llama_stack_api/vector_io/api.py +234 -0
- llama_stack_api/vector_io/fastapi_routes.py +447 -0
- llama_stack_api/{vector_io.py → vector_io/models.py} +99 -377
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
- llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
- llama_stack_api/agents.py +0 -173
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/eval.py +0 -137
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/safety.py +0 -132
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/shields.py +0 -93
- llama_stack_api-0.4.4.dist-info/RECORD +0 -70
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/top_level.txt +0 -0
llama_stack_api/post_training/fastapi_routes.py (new file, +116)

```diff
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""FastAPI router for the Post-Training API.
+
+This module defines the FastAPI router for the Post-Training API using standard
+FastAPI route decorators.
+"""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Body, Depends
+
+from llama_stack_api.router_utils import create_path_dependency, standard_responses
+from llama_stack_api.version import LLAMA_STACK_API_V1ALPHA
+
+from .api import PostTraining
+from .models import (
+    CancelTrainingJobRequest,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
+    ListPostTrainingJobsResponse,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
+)
+
+# Path parameter dependencies for single-field models
+get_training_job_status_request = create_path_dependency(GetTrainingJobStatusRequest)
+cancel_training_job_request = create_path_dependency(CancelTrainingJobRequest)
+get_training_job_artifacts_request = create_path_dependency(GetTrainingJobArtifactsRequest)
+
+
+def create_router(impl: PostTraining) -> APIRouter:
+    """Create a FastAPI router for the Post-Training API."""
+    router = APIRouter(
+        prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
+        tags=["Post Training"],
+        responses=standard_responses,
+    )
+
+    @router.post(
+        "/post-training/supervised-fine-tune",
+        response_model=PostTrainingJob,
+        summary="Run supervised fine-tuning of a model.",
+        description="Run supervised fine-tuning of a model.",
+        responses={200: {"description": "A PostTrainingJob."}},
+    )
+    async def supervised_fine_tune(
+        request: Annotated[SupervisedFineTuneRequest, Body(...)],
+    ) -> PostTrainingJob:
+        return await impl.supervised_fine_tune(request)
+
+    @router.post(
+        "/post-training/preference-optimize",
+        response_model=PostTrainingJob,
+        summary="Run preference optimization of a model.",
+        description="Run preference optimization of a model.",
+        responses={200: {"description": "A PostTrainingJob."}},
+    )
+    async def preference_optimize(
+        request: Annotated[PreferenceOptimizeRequest, Body(...)],
+    ) -> PostTrainingJob:
+        return await impl.preference_optimize(request)
+
+    @router.get(
+        "/post-training/jobs",
+        response_model=ListPostTrainingJobsResponse,
+        summary="Get all training jobs.",
+        description="Get all training jobs.",
+        responses={200: {"description": "A ListPostTrainingJobsResponse."}},
+    )
+    async def get_training_jobs() -> ListPostTrainingJobsResponse:
+        return await impl.get_training_jobs()
+
+    @router.get(
+        "/post-training/job/status",
+        response_model=PostTrainingJobStatusResponse,
+        summary="Get the status of a training job.",
+        description="Get the status of a training job.",
+        responses={200: {"description": "A PostTrainingJobStatusResponse."}},
+    )
+    async def get_training_job_status(
+        request: Annotated[GetTrainingJobStatusRequest, Depends(get_training_job_status_request)],
+    ) -> PostTrainingJobStatusResponse:
+        return await impl.get_training_job_status(request)
+
+    @router.post(
+        "/post-training/job/cancel",
+        summary="Cancel a training job.",
+        description="Cancel a training job.",
+        responses={200: {"description": "Successfully cancelled the training job."}},
+    )
+    async def cancel_training_job(
+        request: Annotated[CancelTrainingJobRequest, Depends(cancel_training_job_request)],
+    ) -> None:
+        return await impl.cancel_training_job(request)
+
+    @router.get(
+        "/post-training/job/artifacts",
+        response_model=PostTrainingJobArtifactsResponse,
+        summary="Get the artifacts of a training job.",
+        description="Get the artifacts of a training job.",
+        responses={200: {"description": "A PostTrainingJobArtifactsResponse."}},
+    )
+    async def get_training_job_artifacts(
+        request: Annotated[GetTrainingJobArtifactsRequest, Depends(get_training_job_artifacts_request)],
+    ) -> PostTrainingJobArtifactsResponse:
+        return await impl.get_training_job_artifacts(request)
+
+    return router
```
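To see how the factory is consumed, here is a minimal sketch of wiring the router into an app. `StubPostTraining` is hypothetical, and the `PostTrainingJob(job_uuid=...)` and `ListPostTrainingJobsResponse(data=[])` constructor fields are assumptions, since those model bodies are not shown in this diff:

```python
from fastapi import FastAPI

from llama_stack_api.post_training import fastapi_routes
from llama_stack_api.post_training.models import (
    ListPostTrainingJobsResponse,
    PostTrainingJob,
    SupervisedFineTuneRequest,
)


class StubPostTraining:
    """Hypothetical in-memory stand-in for a PostTraining implementation."""

    async def supervised_fine_tune(self, request: SupervisedFineTuneRequest) -> PostTrainingJob:
        # Echo the client-supplied job UUID back as the created job.
        return PostTrainingJob(job_uuid=request.job_uuid)

    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        return ListPostTrainingJobsResponse(data=[])

    # Remaining methods omitted: each route closure looks up its impl
    # attribute only when that endpoint is actually called.


app = FastAPI()
app.include_router(fastapi_routes.create_router(StubPostTraining()))
```

The pattern throughout the new modules is the same: one handler closure per route that delegates to `impl`, with single-path-parameter models adapted via `create_path_dependency`.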
llama_stack_api/post_training/models.py (renamed from llama_stack_api/post_training.py, +55 -86)

```diff
@@ -4,17 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+"""Pydantic models for Post-Training API requests and responses.
+
+This module defines the request and response models for the Post-Training API
+using Pydantic with Field descriptions for OpenAPI schema generation.
+"""
+
 from datetime import datetime
 from enum import Enum
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal
 
 from pydantic import BaseModel, Field
 
 from llama_stack_api.common.content_types import URL
 from llama_stack_api.common.job_types import JobStatus
 from llama_stack_api.common.training_types import Checkpoint
-from llama_stack_api.schema_utils import json_schema_type, register_schema
-from llama_stack_api.version import LLAMA_STACK_API_V1ALPHA
+from llama_stack_api.schema_utils import json_schema_type, register_schema
 
 
 @json_schema_type
```
```diff
@@ -285,86 +290,50 @@ class PostTrainingJobArtifactsResponse(BaseModel):
     # TODO(ashwin): metrics, evals
 
 
-    [... 48 removed lines (old 288-335) whose content was not captured in this diff view ...]
-        """
-        ...
-
-    @webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
-    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
-        """Get all training jobs.
-
-        :returns: A ListPostTrainingJobsResponse.
-        """
-        ...
-
-    @webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
-        """Get the status of a training job.
-
-        :param job_uuid: The UUID of the job to get the status of.
-        :returns: A PostTrainingJobStatusResponse.
-        """
-        ...
-
-    @webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
-    async def cancel_training_job(self, job_uuid: str) -> None:
-        """Cancel a training job.
-
-        :param job_uuid: The UUID of the job to cancel.
-        """
-        ...
-
-    @webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
-        """Get the artifacts of a training job.
-
-        :param job_uuid: The UUID of the job to get the artifacts of.
-        :returns: A PostTrainingJobArtifactsResponse.
-        """
-        ...
+@json_schema_type
+class SupervisedFineTuneRequest(BaseModel):
+    """Request to run supervised fine-tuning of a model."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to create.")
+    training_config: TrainingConfig = Field(..., description="The training configuration.")
+    hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration.")
+    logger_config: dict[str, Any] = Field(..., description="The logger configuration.")
+    model: str | None = Field(
+        default=None,
+        description="Model descriptor for training if not in provider config",
+    )
+    checkpoint_dir: str | None = Field(default=None, description="The directory to save checkpoint(s) to.")
+    algorithm_config: AlgorithmConfig | None = Field(default=None, description="The algorithm configuration.")
+
+
+@json_schema_type
+class PreferenceOptimizeRequest(BaseModel):
+    """Request to run preference optimization of a model."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to create.")
+    finetuned_model: str = Field(..., description="The model to fine-tune.")
+    algorithm_config: DPOAlignmentConfig = Field(..., description="The algorithm configuration.")
+    training_config: TrainingConfig = Field(..., description="The training configuration.")
+    hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration.")
+    logger_config: dict[str, Any] = Field(..., description="The logger configuration.")
+
+
+@json_schema_type
+class GetTrainingJobStatusRequest(BaseModel):
+    """Request to get the status of a training job."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to get the status of.")
+
+
+@json_schema_type
+class CancelTrainingJobRequest(BaseModel):
+    """Request to cancel a training job."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to cancel.")
+
+
+@json_schema_type
+class GetTrainingJobArtifactsRequest(BaseModel):
+    """Request to get the artifacts of a training job."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to get the artifacts of.")
```
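The substantive API change in this hunk is the calling convention: the removed protocol methods took flat keyword arguments such as `job_uuid: str`, while the new surface wraps every call in a Pydantic request model, so validation and OpenAPI documentation come from the model itself. A small sketch of that behavior, using only models whose full definitions appear above:

```python
from pydantic import ValidationError

from llama_stack_api.post_training.models import (
    CancelTrainingJobRequest,
    GetTrainingJobStatusRequest,
)

# Constructing a request model validates and normalizes the payload.
req = GetTrainingJobStatusRequest(job_uuid="job-123")
print(req.model_dump())  # {'job_uuid': 'job-123'}

# Field(..., description=...) makes job_uuid required, so omitting it fails.
try:
    CancelTrainingJobRequest()  # type: ignore[call-arg]
except ValidationError as err:
    print(err.error_count(), "validation error")
```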
llama_stack_api/prompts/__init__.py (new file, +47)

```diff
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Prompts API protocol and models.
+
+This module contains the Prompts protocol definition.
+Pydantic models are defined in llama_stack_api.prompts.models.
+The FastAPI router is defined in llama_stack_api.prompts.fastapi_routes.
+"""
+
+# Import fastapi_routes for router factory access
+from . import fastapi_routes
+
+# Import protocol for FastAPI router
+from .api import Prompts
+
+# Import models for re-export
+from .models import (
+    CreatePromptRequest,
+    DeletePromptRequest,
+    GetPromptRequest,
+    ListPromptsResponse,
+    ListPromptVersionsRequest,
+    Prompt,
+    SetDefaultVersionBodyRequest,
+    SetDefaultVersionRequest,
+    UpdatePromptBodyRequest,
+    UpdatePromptRequest,
+)
+
+__all__ = [
+    "CreatePromptRequest",
+    "DeletePromptRequest",
+    "GetPromptRequest",
+    "ListPromptVersionsRequest",
+    "ListPromptsResponse",
+    "Prompt",
+    "Prompts",
+    "SetDefaultVersionBodyRequest",
+    "SetDefaultVersionRequest",
+    "UpdatePromptBodyRequest",
+    "UpdatePromptRequest",
+    "fastapi_routes",
+]
```
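With `fastapi_routes` included in `__all__`, the protocol, the models, and the router factory all travel under one import path, a sketch of the resulting surface:

```python
# The package root re-exports everything a consumer needs.
from llama_stack_api.prompts import Prompt, Prompts, fastapi_routes

assert callable(fastapi_routes.create_router)
```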
llama_stack_api/prompts/api.py (new file, +44)

```diff
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Prompts API protocol definition.
+
+This module contains the Prompts protocol definition.
+Pydantic models are defined in llama_stack_api.prompts.models.
+The FastAPI router is defined in llama_stack_api.prompts.fastapi_routes.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from .models import (
+    CreatePromptRequest,
+    DeletePromptRequest,
+    GetPromptRequest,
+    ListPromptsResponse,
+    ListPromptVersionsRequest,
+    Prompt,
+    SetDefaultVersionRequest,
+    UpdatePromptRequest,
+)
+
+
+@runtime_checkable
+class Prompts(Protocol):
+    """Protocol for prompt management operations."""
+
+    async def list_prompts(self) -> ListPromptsResponse: ...
+
+    async def list_prompt_versions(self, request: ListPromptVersionsRequest) -> ListPromptsResponse: ...
+
+    async def get_prompt(self, request: GetPromptRequest) -> Prompt: ...
+
+    async def create_prompt(self, request: CreatePromptRequest) -> Prompt: ...
+
+    async def update_prompt(self, request: UpdatePromptRequest) -> Prompt: ...
+
+    async def delete_prompt(self, request: DeletePromptRequest) -> None: ...
+
+    async def set_default_version(self, request: SetDefaultVersionRequest) -> Prompt: ...
```
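Because the protocol is decorated with `@runtime_checkable`, `isinstance()` can verify structural conformance at runtime, checking that every protocol method exists on the instance (signatures are not verified). A sketch of that behavior, using a deliberately incomplete class:

```python
from llama_stack_api.prompts import Prompts
from llama_stack_api.prompts.models import ListPromptsResponse


class PartialPrompts:
    """Implements only one of the seven protocol methods."""

    async def list_prompts(self) -> ListPromptsResponse: ...


# False: the other six protocol methods are missing. A class defining all
# seven async methods would pass this check.
print(isinstance(PartialPrompts(), Prompts))
```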
llama_stack_api/prompts/fastapi_routes.py (new file, +163)

```diff
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""FastAPI router for the Prompts API.
+
+This module defines the FastAPI router for the Prompts API using standard
+FastAPI route decorators.
+"""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Body, Depends, Path, Query
+
+from llama_stack_api.router_utils import create_path_dependency, standard_responses
+from llama_stack_api.version import LLAMA_STACK_API_V1
+
+from .api import Prompts
+from .models import (
+    CreatePromptRequest,
+    DeletePromptRequest,
+    GetPromptRequest,
+    ListPromptsResponse,
+    ListPromptVersionsRequest,
+    Prompt,
+    SetDefaultVersionBodyRequest,
+    SetDefaultVersionRequest,
+    UpdatePromptBodyRequest,
+    UpdatePromptRequest,
+)
+
+# Path parameter dependencies for single-field models
+list_prompt_versions_request = create_path_dependency(ListPromptVersionsRequest)
+delete_prompt_request = create_path_dependency(DeletePromptRequest)
+
+
+def create_router(impl: Prompts) -> APIRouter:
+    """Create a FastAPI router for the Prompts API.
+
+    Args:
+        impl: The Prompts implementation instance
+
+    Returns:
+        APIRouter configured for the Prompts API
+    """
+    router = APIRouter(
+        prefix=f"/{LLAMA_STACK_API_V1}",
+        tags=["Prompts"],
+        responses=standard_responses,
+    )
+
+    @router.get(
+        "/prompts",
+        response_model=ListPromptsResponse,
+        summary="List all prompts.",
+        description="List all prompts.",
+        responses={
+            200: {"description": "A ListPromptsResponse containing all prompts."},
+        },
+    )
+    async def list_prompts() -> ListPromptsResponse:
+        return await impl.list_prompts()
+
+    @router.get(
+        "/prompts/{prompt_id}/versions",
+        response_model=ListPromptsResponse,
+        summary="List prompt versions.",
+        description="List all versions of a specific prompt.",
+        responses={
+            200: {"description": "A ListPromptsResponse containing all versions of the prompt."},
+        },
+    )
+    async def list_prompt_versions(
+        request: Annotated[ListPromptVersionsRequest, Depends(list_prompt_versions_request)],
+    ) -> ListPromptsResponse:
+        return await impl.list_prompt_versions(request)
+
+    @router.get(
+        "/prompts/{prompt_id}",
+        response_model=Prompt,
+        summary="Get a prompt.",
+        description="Get a prompt by its identifier and optional version.",
+        responses={
+            200: {"description": "A Prompt resource."},
+        },
+    )
+    async def get_prompt(
+        prompt_id: Annotated[str, Path(description="The identifier of the prompt to get.")],
+        version: Annotated[
+            int | None, Query(description="The version of the prompt to get (defaults to latest).")
+        ] = None,
+    ) -> Prompt:
+        request = GetPromptRequest(prompt_id=prompt_id, version=version)
+        return await impl.get_prompt(request)
+
+    @router.post(
+        "/prompts",
+        response_model=Prompt,
+        summary="Create a prompt.",
+        description="Create a new prompt.",
+        responses={
+            200: {"description": "The created Prompt resource."},
+        },
+    )
+    async def create_prompt(
+        request: Annotated[CreatePromptRequest, Body(...)],
+    ) -> Prompt:
+        return await impl.create_prompt(request)
+
+    @router.put(
+        "/prompts/{prompt_id}",
+        response_model=Prompt,
+        summary="Update a prompt.",
+        description="Update an existing prompt (increments version).",
+        responses={
+            200: {"description": "The updated Prompt resource with incremented version."},
+        },
+    )
+    async def update_prompt(
+        prompt_id: Annotated[str, Path(description="The identifier of the prompt to update.")],
+        body: Annotated[UpdatePromptBodyRequest, Body(...)],
+    ) -> Prompt:
+        request = UpdatePromptRequest(
+            prompt_id=prompt_id,
+            prompt=body.prompt,
+            version=body.version,
+            variables=body.variables,
+            set_as_default=body.set_as_default,
+        )
+        return await impl.update_prompt(request)
+
+    @router.delete(
+        "/prompts/{prompt_id}",
+        summary="Delete a prompt.",
+        description="Delete a prompt.",
+        responses={
+            200: {"description": "The prompt was successfully deleted."},
+        },
+    )
+    async def delete_prompt(
+        request: Annotated[DeletePromptRequest, Depends(delete_prompt_request)],
+    ) -> None:
+        return await impl.delete_prompt(request)
+
+    @router.put(
+        "/prompts/{prompt_id}/set-default-version",
+        response_model=Prompt,
+        summary="Set prompt version.",
+        description="Set which version of a prompt should be the default in get_prompt (latest).",
+        responses={
+            200: {"description": "The prompt with the specified version now set as default."},
+        },
+    )
+    async def set_default_version(
+        prompt_id: Annotated[str, Path(description="The identifier of the prompt.")],
+        body: Annotated[SetDefaultVersionBodyRequest, Body(...)],
+    ) -> Prompt:
+        request = SetDefaultVersionRequest(prompt_id=prompt_id, version=body.version)
+        return await impl.set_default_version(request)
+
+    return router
```
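A sketch of exercising this router in-process with FastAPI's `TestClient`. `StubPrompts` is hypothetical, the `ListPromptsResponse(data=[])` field name is an assumption (the model body is not shown in this diff), and the `/v1` prefix assumes `LLAMA_STACK_API_V1` renders as `"v1"`:

```python
from fastapi import FastAPI
from fastapi.testclient import TestClient

from llama_stack_api.prompts import fastapi_routes
from llama_stack_api.prompts.models import ListPromptsResponse


class StubPrompts:
    """Hypothetical stub: only the method exercised below is implemented."""

    async def list_prompts(self) -> ListPromptsResponse:
        return ListPromptsResponse(data=[])


app = FastAPI()
app.include_router(fastapi_routes.create_router(StubPrompts()))

client = TestClient(app)
resp = client.get("/v1/prompts")  # routes mount under the version prefix
print(resp.status_code, resp.json())
```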