llama-stack-api 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. llama_stack_api/__init__.py +1100 -0
  2. llama_stack_api/admin/__init__.py +45 -0
  3. llama_stack_api/admin/api.py +72 -0
  4. llama_stack_api/admin/fastapi_routes.py +117 -0
  5. llama_stack_api/admin/models.py +113 -0
  6. llama_stack_api/agents/__init__.py +38 -0
  7. llama_stack_api/agents/api.py +52 -0
  8. llama_stack_api/agents/fastapi_routes.py +268 -0
  9. llama_stack_api/agents/models.py +181 -0
  10. llama_stack_api/batches/__init__.py +40 -0
  11. llama_stack_api/batches/api.py +53 -0
  12. llama_stack_api/batches/fastapi_routes.py +113 -0
  13. llama_stack_api/batches/models.py +78 -0
  14. llama_stack_api/benchmarks/__init__.py +43 -0
  15. llama_stack_api/benchmarks/api.py +39 -0
  16. llama_stack_api/benchmarks/fastapi_routes.py +109 -0
  17. llama_stack_api/benchmarks/models.py +109 -0
  18. llama_stack_api/common/__init__.py +5 -0
  19. llama_stack_api/common/content_types.py +101 -0
  20. llama_stack_api/common/errors.py +110 -0
  21. llama_stack_api/common/job_types.py +38 -0
  22. llama_stack_api/common/responses.py +77 -0
  23. llama_stack_api/common/training_types.py +47 -0
  24. llama_stack_api/common/type_system.py +146 -0
  25. llama_stack_api/connectors/__init__.py +38 -0
  26. llama_stack_api/connectors/api.py +50 -0
  27. llama_stack_api/connectors/fastapi_routes.py +103 -0
  28. llama_stack_api/connectors/models.py +103 -0
  29. llama_stack_api/conversations/__init__.py +61 -0
  30. llama_stack_api/conversations/api.py +44 -0
  31. llama_stack_api/conversations/fastapi_routes.py +177 -0
  32. llama_stack_api/conversations/models.py +245 -0
  33. llama_stack_api/datasetio/__init__.py +34 -0
  34. llama_stack_api/datasetio/api.py +42 -0
  35. llama_stack_api/datasetio/fastapi_routes.py +94 -0
  36. llama_stack_api/datasetio/models.py +48 -0
  37. llama_stack_api/datasets/__init__.py +61 -0
  38. llama_stack_api/datasets/api.py +35 -0
  39. llama_stack_api/datasets/fastapi_routes.py +104 -0
  40. llama_stack_api/datasets/models.py +152 -0
  41. llama_stack_api/datatypes.py +373 -0
  42. llama_stack_api/eval/__init__.py +55 -0
  43. llama_stack_api/eval/api.py +51 -0
  44. llama_stack_api/eval/compat.py +300 -0
  45. llama_stack_api/eval/fastapi_routes.py +126 -0
  46. llama_stack_api/eval/models.py +141 -0
  47. llama_stack_api/file_processors/__init__.py +27 -0
  48. llama_stack_api/file_processors/api.py +64 -0
  49. llama_stack_api/file_processors/fastapi_routes.py +78 -0
  50. llama_stack_api/file_processors/models.py +42 -0
  51. llama_stack_api/files/__init__.py +35 -0
  52. llama_stack_api/files/api.py +51 -0
  53. llama_stack_api/files/fastapi_routes.py +124 -0
  54. llama_stack_api/files/models.py +107 -0
  55. llama_stack_api/inference/__init__.py +207 -0
  56. llama_stack_api/inference/api.py +93 -0
  57. llama_stack_api/inference/fastapi_routes.py +243 -0
  58. llama_stack_api/inference/models.py +1035 -0
  59. llama_stack_api/inspect_api/__init__.py +37 -0
  60. llama_stack_api/inspect_api/api.py +25 -0
  61. llama_stack_api/inspect_api/fastapi_routes.py +76 -0
  62. llama_stack_api/inspect_api/models.py +28 -0
  63. llama_stack_api/internal/__init__.py +9 -0
  64. llama_stack_api/internal/kvstore.py +28 -0
  65. llama_stack_api/internal/sqlstore.py +81 -0
  66. llama_stack_api/models/__init__.py +47 -0
  67. llama_stack_api/models/api.py +38 -0
  68. llama_stack_api/models/fastapi_routes.py +104 -0
  69. llama_stack_api/models/models.py +157 -0
  70. llama_stack_api/openai_responses.py +1494 -0
  71. llama_stack_api/post_training/__init__.py +73 -0
  72. llama_stack_api/post_training/api.py +36 -0
  73. llama_stack_api/post_training/fastapi_routes.py +116 -0
  74. llama_stack_api/post_training/models.py +339 -0
  75. llama_stack_api/prompts/__init__.py +47 -0
  76. llama_stack_api/prompts/api.py +44 -0
  77. llama_stack_api/prompts/fastapi_routes.py +163 -0
  78. llama_stack_api/prompts/models.py +177 -0
  79. llama_stack_api/providers/__init__.py +33 -0
  80. llama_stack_api/providers/api.py +16 -0
  81. llama_stack_api/providers/fastapi_routes.py +57 -0
  82. llama_stack_api/providers/models.py +24 -0
  83. llama_stack_api/rag_tool.py +168 -0
  84. llama_stack_api/resource.py +36 -0
  85. llama_stack_api/router_utils.py +160 -0
  86. llama_stack_api/safety/__init__.py +37 -0
  87. llama_stack_api/safety/api.py +29 -0
  88. llama_stack_api/safety/datatypes.py +83 -0
  89. llama_stack_api/safety/fastapi_routes.py +55 -0
  90. llama_stack_api/safety/models.py +38 -0
  91. llama_stack_api/schema_utils.py +251 -0
  92. llama_stack_api/scoring/__init__.py +66 -0
  93. llama_stack_api/scoring/api.py +35 -0
  94. llama_stack_api/scoring/fastapi_routes.py +67 -0
  95. llama_stack_api/scoring/models.py +81 -0
  96. llama_stack_api/scoring_functions/__init__.py +50 -0
  97. llama_stack_api/scoring_functions/api.py +39 -0
  98. llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
  99. llama_stack_api/scoring_functions/models.py +214 -0
  100. llama_stack_api/shields/__init__.py +41 -0
  101. llama_stack_api/shields/api.py +39 -0
  102. llama_stack_api/shields/fastapi_routes.py +104 -0
  103. llama_stack_api/shields/models.py +74 -0
  104. llama_stack_api/tools.py +226 -0
  105. llama_stack_api/validators.py +46 -0
  106. llama_stack_api/vector_io/__init__.py +88 -0
  107. llama_stack_api/vector_io/api.py +234 -0
  108. llama_stack_api/vector_io/fastapi_routes.py +447 -0
  109. llama_stack_api/vector_io/models.py +663 -0
  110. llama_stack_api/vector_stores.py +53 -0
  111. llama_stack_api/version.py +9 -0
  112. {llama_stack_api-0.4.3.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
  113. llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
  114. llama_stack_api-0.5.0rc1.dist-info/top_level.txt +1 -0
  115. llama_stack_api-0.4.3.dist-info/RECORD +0 -4
  116. llama_stack_api-0.4.3.dist-info/top_level.txt +0 -1
  117. {llama_stack_api-0.4.3.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
llama_stack_api/post_training/__init__.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Post-Training API protocol and models.
+
+This module contains the Post-Training protocol definition.
+Pydantic models are defined in llama_stack_api.post_training.models.
+The FastAPI router is defined in llama_stack_api.post_training.fastapi_routes.
+"""
+
+# Import fastapi_routes for router factory access
+from . import fastapi_routes
+
+# Import protocol for re-export
+from .api import PostTraining
+
+# Import models for re-export
+from .models import (
+    AlgorithmConfig,
+    CancelTrainingJobRequest,
+    DataConfig,
+    DatasetFormat,
+    DPOAlignmentConfig,
+    DPOLossType,
+    EfficiencyConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
+    ListPostTrainingJobsResponse,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    OptimizerType,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobLogStream,
+    PostTrainingJobStatusResponse,
+    PostTrainingRLHFRequest,
+    PreferenceOptimizeRequest,
+    QATFinetuningConfig,
+    RLHFAlgorithm,
+    SupervisedFineTuneRequest,
+    TrainingConfig,
+)
+
+__all__ = [
+    "PostTraining",
+    "AlgorithmConfig",
+    "CancelTrainingJobRequest",
+    "DataConfig",
+    "DatasetFormat",
+    "DPOAlignmentConfig",
+    "DPOLossType",
+    "EfficiencyConfig",
+    "GetTrainingJobArtifactsRequest",
+    "GetTrainingJobStatusRequest",
+    "ListPostTrainingJobsResponse",
+    "LoraFinetuningConfig",
+    "OptimizerConfig",
+    "OptimizerType",
+    "PostTrainingJob",
+    "PostTrainingJobArtifactsResponse",
+    "PostTrainingJobLogStream",
+    "PostTrainingJobStatusResponse",
+    "PostTrainingRLHFRequest",
+    "PreferenceOptimizeRequest",
+    "QATFinetuningConfig",
+    "RLHFAlgorithm",
+    "SupervisedFineTuneRequest",
+    "TrainingConfig",
+    "fastapi_routes",
+]
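The __init__ module above re-exports the protocol (defined in .api), the Pydantic models (.models) and the router factory module (.fastapi_routes) at the subpackage root, so callers can ignore the internal layout. A minimal sketch of the resulting import surface, assuming the wheel is installed:

from llama_stack_api.post_training import (
    PostTraining,               # runtime-checkable protocol from .api
    SupervisedFineTuneRequest,  # request model from .models
    TrainingConfig,
    fastapi_routes,             # module exposing create_router()
)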
llama_stack_api/post_training/api.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Protocol, runtime_checkable
+
+from .models import (
+    CancelTrainingJobRequest,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
+    ListPostTrainingJobsResponse,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
+)
+
+
+@runtime_checkable
+class PostTraining(Protocol):
+    async def supervised_fine_tune(self, request: SupervisedFineTuneRequest) -> PostTrainingJob: ...
+
+    async def preference_optimize(self, request: PreferenceOptimizeRequest) -> PostTrainingJob: ...
+
+    async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
+
+    async def get_training_job_status(self, request: GetTrainingJobStatusRequest) -> PostTrainingJobStatusResponse: ...
+
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None: ...
+
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse: ...
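Each protocol method takes a single request model and returns a response model, and the class is decorated with runtime_checkable, so isinstance() only verifies that the method names exist. A minimal in-memory stub is sketched below; the class name and its trivial bodies are illustrative only and use just the models shown in this diff:

from llama_stack_api.post_training.api import PostTraining
from llama_stack_api.post_training.models import (
    CancelTrainingJobRequest,
    GetTrainingJobArtifactsRequest,
    GetTrainingJobStatusRequest,
    ListPostTrainingJobsResponse,
    PostTrainingJob,
    PostTrainingJobArtifactsResponse,
    PostTrainingJobStatusResponse,
    PreferenceOptimizeRequest,
    SupervisedFineTuneRequest,
)


class InMemoryPostTraining:
    # Hypothetical stub; not part of the package.

    async def supervised_fine_tune(self, request: SupervisedFineTuneRequest) -> PostTrainingJob:
        return PostTrainingJob(job_uuid=request.job_uuid)

    async def preference_optimize(self, request: PreferenceOptimizeRequest) -> PostTrainingJob:
        return PostTrainingJob(job_uuid=request.job_uuid)

    async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
        return ListPostTrainingJobsResponse(data=[])

    async def get_training_job_status(self, request: GetTrainingJobStatusRequest) -> PostTrainingJobStatusResponse:
        raise NotImplementedError  # a real provider would look up a job store

    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
        return None

    async def get_training_job_artifacts(
        self, request: GetTrainingJobArtifactsRequest
    ) -> PostTrainingJobArtifactsResponse:
        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid)


# Passes because @runtime_checkable checks method names, not signatures.
assert isinstance(InMemoryPostTraining(), PostTraining)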
llama_stack_api/post_training/fastapi_routes.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""FastAPI router for the Post-Training API.
+
+This module defines the FastAPI router for the Post-Training API using standard
+FastAPI route decorators.
+"""
+
+from typing import Annotated
+
+from fastapi import APIRouter, Body, Depends
+
+from llama_stack_api.router_utils import create_path_dependency, standard_responses
+from llama_stack_api.version import LLAMA_STACK_API_V1ALPHA
+
+from .api import PostTraining
+from .models import (
+    CancelTrainingJobRequest,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
+    ListPostTrainingJobsResponse,
+    PostTrainingJob,
+    PostTrainingJobArtifactsResponse,
+    PostTrainingJobStatusResponse,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
+)
+
+# Path parameter dependencies for single-field models
+get_training_job_status_request = create_path_dependency(GetTrainingJobStatusRequest)
+cancel_training_job_request = create_path_dependency(CancelTrainingJobRequest)
+get_training_job_artifacts_request = create_path_dependency(GetTrainingJobArtifactsRequest)
+
+
+def create_router(impl: PostTraining) -> APIRouter:
+    """Create a FastAPI router for the Post-Training API."""
+    router = APIRouter(
+        prefix=f"/{LLAMA_STACK_API_V1ALPHA}",
+        tags=["Post Training"],
+        responses=standard_responses,
+    )
+
+    @router.post(
+        "/post-training/supervised-fine-tune",
+        response_model=PostTrainingJob,
+        summary="Run supervised fine-tuning of a model.",
+        description="Run supervised fine-tuning of a model.",
+        responses={200: {"description": "A PostTrainingJob."}},
+    )
+    async def supervised_fine_tune(
+        request: Annotated[SupervisedFineTuneRequest, Body(...)],
+    ) -> PostTrainingJob:
+        return await impl.supervised_fine_tune(request)
+
+    @router.post(
+        "/post-training/preference-optimize",
+        response_model=PostTrainingJob,
+        summary="Run preference optimization of a model.",
+        description="Run preference optimization of a model.",
+        responses={200: {"description": "A PostTrainingJob."}},
+    )
+    async def preference_optimize(
+        request: Annotated[PreferenceOptimizeRequest, Body(...)],
+    ) -> PostTrainingJob:
+        return await impl.preference_optimize(request)
+
+    @router.get(
+        "/post-training/jobs",
+        response_model=ListPostTrainingJobsResponse,
+        summary="Get all training jobs.",
+        description="Get all training jobs.",
+        responses={200: {"description": "A ListPostTrainingJobsResponse."}},
+    )
+    async def get_training_jobs() -> ListPostTrainingJobsResponse:
+        return await impl.get_training_jobs()
+
+    @router.get(
+        "/post-training/job/status",
+        response_model=PostTrainingJobStatusResponse,
+        summary="Get the status of a training job.",
+        description="Get the status of a training job.",
+        responses={200: {"description": "A PostTrainingJobStatusResponse."}},
+    )
+    async def get_training_job_status(
+        request: Annotated[GetTrainingJobStatusRequest, Depends(get_training_job_status_request)],
+    ) -> PostTrainingJobStatusResponse:
+        return await impl.get_training_job_status(request)
+
+    @router.post(
+        "/post-training/job/cancel",
+        summary="Cancel a training job.",
+        description="Cancel a training job.",
+        responses={200: {"description": "Successfully cancelled the training job."}},
+    )
+    async def cancel_training_job(
+        request: Annotated[CancelTrainingJobRequest, Depends(cancel_training_job_request)],
+    ) -> None:
+        return await impl.cancel_training_job(request)
+
+    @router.get(
+        "/post-training/job/artifacts",
+        response_model=PostTrainingJobArtifactsResponse,
+        summary="Get the artifacts of a training job.",
+        description="Get the artifacts of a training job.",
+        responses={200: {"description": "A PostTrainingJobArtifactsResponse."}},
+    )
+    async def get_training_job_artifacts(
+        request: Annotated[GetTrainingJobArtifactsRequest, Depends(get_training_job_artifacts_request)],
+    ) -> PostTrainingJobArtifactsResponse:
+        return await impl.get_training_job_artifacts(request)
+
+    return router
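create_router binds each handler to a concrete PostTraining implementation through closures rather than FastAPI dependency injection, so serving the API is just a matter of including the returned router in an app. A short sketch, assuming any object that satisfies the protocol (the build_app helper is illustrative, not part of the package):

from fastapi import FastAPI

from llama_stack_api.post_training import fastapi_routes
from llama_stack_api.post_training.api import PostTraining


def build_app(impl: PostTraining) -> FastAPI:
    # All routes are registered under the /{LLAMA_STACK_API_V1ALPHA} prefix
    # set by the router factory above.
    app = FastAPI()
    app.include_router(fastapi_routes.create_router(impl))
    return app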
llama_stack_api/post_training/models.py
@@ -0,0 +1,339 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Pydantic models for Post-Training API requests and responses.
+
+This module defines the request and response models for the Post-Training API
+using Pydantic with Field descriptions for OpenAPI schema generation.
+"""
+
+from datetime import datetime
+from enum import Enum
+from typing import Annotated, Any, Literal
+
+from pydantic import BaseModel, Field
+
+from llama_stack_api.common.content_types import URL
+from llama_stack_api.common.job_types import JobStatus
+from llama_stack_api.common.training_types import Checkpoint
+from llama_stack_api.schema_utils import json_schema_type, register_schema
+
+
+@json_schema_type
+class OptimizerType(Enum):
+    """Available optimizer algorithms for training.
+    :cvar adam: Adaptive Moment Estimation optimizer
+    :cvar adamw: AdamW optimizer with weight decay
+    :cvar sgd: Stochastic Gradient Descent optimizer
+    """
+
+    adam = "adam"
+    adamw = "adamw"
+    sgd = "sgd"
+
+
+@json_schema_type
+class DatasetFormat(Enum):
+    """Format of the training dataset.
+    :cvar instruct: Instruction-following format with prompt and completion
+    :cvar dialog: Multi-turn conversation format with messages
+    """
+
+    instruct = "instruct"
+    dialog = "dialog"
+
+
+@json_schema_type
+class DataConfig(BaseModel):
+    """Configuration for training data and data loading.
+
+    :param dataset_id: Unique identifier for the training dataset
+    :param batch_size: Number of samples per training batch
+    :param shuffle: Whether to shuffle the dataset during training
+    :param data_format: Format of the dataset (instruct or dialog)
+    :param validation_dataset_id: (Optional) Unique identifier for the validation dataset
+    :param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency
+    :param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens
+    """
+
+    dataset_id: str
+    batch_size: int
+    shuffle: bool
+    data_format: DatasetFormat
+    validation_dataset_id: str | None = None
+    packed: bool | None = False
+    train_on_input: bool | None = False
+
+
+@json_schema_type
+class OptimizerConfig(BaseModel):
+    """Configuration parameters for the optimization algorithm.
+
+    :param optimizer_type: Type of optimizer to use (adam, adamw, or sgd)
+    :param lr: Learning rate for the optimizer
+    :param weight_decay: Weight decay coefficient for regularization
+    :param num_warmup_steps: Number of steps for learning rate warmup
+    """
+
+    optimizer_type: OptimizerType
+    lr: float
+    weight_decay: float
+    num_warmup_steps: int
+
+
+@json_schema_type
+class EfficiencyConfig(BaseModel):
+    """Configuration for memory and compute efficiency optimizations.
+
+    :param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage
+    :param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory
+    :param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping
+    :param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU
+    """
+
+    enable_activation_checkpointing: bool | None = False
+    enable_activation_offloading: bool | None = False
+    memory_efficient_fsdp_wrap: bool | None = False
+    fsdp_cpu_offload: bool | None = False
+
+
+@json_schema_type
+class TrainingConfig(BaseModel):
+    """Comprehensive configuration for the training process.
+
+    :param n_epochs: Number of training epochs to run
+    :param max_steps_per_epoch: Maximum number of steps to run per epoch
+    :param gradient_accumulation_steps: Number of steps to accumulate gradients before updating
+    :param max_validation_steps: (Optional) Maximum number of validation steps per epoch
+    :param data_config: (Optional) Configuration for data loading and formatting
+    :param optimizer_config: (Optional) Configuration for the optimization algorithm
+    :param efficiency_config: (Optional) Configuration for memory and compute optimizations
+    :param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32)
+    """
+
+    n_epochs: int
+    max_steps_per_epoch: int = 1
+    gradient_accumulation_steps: int = 1
+    max_validation_steps: int | None = 1
+    data_config: DataConfig | None = None
+    optimizer_config: OptimizerConfig | None = None
+    efficiency_config: EfficiencyConfig | None = None
+    dtype: str | None = "bf16"
+
+
+@json_schema_type
+class LoraFinetuningConfig(BaseModel):
+    """Configuration for Low-Rank Adaptation (LoRA) fine-tuning.
+
+    :param type: Algorithm type identifier, always "LoRA"
+    :param lora_attn_modules: List of attention module names to apply LoRA to
+    :param apply_lora_to_mlp: Whether to apply LoRA to MLP layers
+    :param apply_lora_to_output: Whether to apply LoRA to output projection layers
+    :param rank: Rank of the LoRA adaptation (lower rank = fewer parameters)
+    :param alpha: LoRA scaling parameter that controls adaptation strength
+    :param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)
+    :param quantize_base: (Optional) Whether to quantize the base model weights
+    """
+
+    type: Literal["LoRA"] = "LoRA"
+    lora_attn_modules: list[str]
+    apply_lora_to_mlp: bool
+    apply_lora_to_output: bool
+    rank: int
+    alpha: int
+    use_dora: bool | None = False
+    quantize_base: bool | None = False
+
+
+@json_schema_type
+class QATFinetuningConfig(BaseModel):
+    """Configuration for Quantization-Aware Training (QAT) fine-tuning.
+
+    :param type: Algorithm type identifier, always "QAT"
+    :param quantizer_name: Name of the quantization algorithm to use
+    :param group_size: Size of groups for grouped quantization
+    """
+
+    type: Literal["QAT"] = "QAT"
+    quantizer_name: str
+    group_size: int
+
+
+AlgorithmConfig = Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")]
+register_schema(AlgorithmConfig, name="AlgorithmConfig")
+
+
+@json_schema_type
+class PostTrainingJobLogStream(BaseModel):
+    """Stream of logs from a finetuning job.
+
+    :param job_uuid: Unique identifier for the training job
+    :param log_lines: List of log message strings from the training process
+    """
+
+    job_uuid: str
+    log_lines: list[str]
+
+
+@json_schema_type
+class RLHFAlgorithm(Enum):
+    """Available reinforcement learning from human feedback algorithms.
+    :cvar dpo: Direct Preference Optimization algorithm
+    """
+
+    dpo = "dpo"
+
+
+@json_schema_type
+class DPOLossType(Enum):
+    sigmoid = "sigmoid"
+    hinge = "hinge"
+    ipo = "ipo"
+    kto_pair = "kto_pair"
+
+
+@json_schema_type
+class DPOAlignmentConfig(BaseModel):
+    """Configuration for Direct Preference Optimization (DPO) alignment.
+
+    :param beta: Temperature parameter for the DPO loss
+    :param loss_type: The type of loss function to use for DPO
+    """
+
+    beta: float
+    loss_type: DPOLossType = DPOLossType.sigmoid
+
+
+@json_schema_type
+class PostTrainingRLHFRequest(BaseModel):
+    """Request to finetune a model using reinforcement learning from human feedback.
+
+    :param job_uuid: Unique identifier for the training job
+    :param finetuned_model: URL or path to the base model to fine-tune
+    :param dataset_id: Unique identifier for the training dataset
+    :param validation_dataset_id: Unique identifier for the validation dataset
+    :param algorithm: RLHF algorithm to use for training
+    :param algorithm_config: Configuration parameters for the RLHF algorithm
+    :param optimizer_config: Configuration parameters for the optimization algorithm
+    :param training_config: Configuration parameters for the training process
+    :param hyperparam_search_config: Configuration for hyperparameter search
+    :param logger_config: Configuration for training logging
+    """
+
+    job_uuid: str
+
+    finetuned_model: URL
+
+    dataset_id: str
+    validation_dataset_id: str
+
+    algorithm: RLHFAlgorithm
+    algorithm_config: DPOAlignmentConfig
+
+    optimizer_config: OptimizerConfig
+    training_config: TrainingConfig
+
+    # TODO: define these
+    hyperparam_search_config: dict[str, Any]
+    logger_config: dict[str, Any]
+
+
+@json_schema_type
+class PostTrainingJob(BaseModel):
+    job_uuid: str
+
+
+@json_schema_type
+class PostTrainingJobStatusResponse(BaseModel):
+    """Status of a finetuning job.
+
+    :param job_uuid: Unique identifier for the training job
+    :param status: Current status of the training job
+    :param scheduled_at: (Optional) Timestamp when the job was scheduled
+    :param started_at: (Optional) Timestamp when the job execution began
+    :param completed_at: (Optional) Timestamp when the job finished, if completed
+    :param resources_allocated: (Optional) Information about computational resources allocated to the job
+    :param checkpoints: List of model checkpoints created during training
+    """
+
+    job_uuid: str
+    status: JobStatus
+
+    scheduled_at: datetime | None = None
+    started_at: datetime | None = None
+    completed_at: datetime | None = None
+
+    resources_allocated: dict[str, Any] | None = None
+
+    checkpoints: list[Checkpoint] = Field(default_factory=list)
+
+
+@json_schema_type
+class ListPostTrainingJobsResponse(BaseModel):
+    data: list[PostTrainingJob]
+
+
+@json_schema_type
+class PostTrainingJobArtifactsResponse(BaseModel):
+    """Artifacts of a finetuning job.
+
+    :param job_uuid: Unique identifier for the training job
+    :param checkpoints: List of model checkpoints created during training
+    """
+
+    job_uuid: str
+    checkpoints: list[Checkpoint] = Field(default_factory=list)
+
+    # TODO(ashwin): metrics, evals
+
+
+@json_schema_type
+class SupervisedFineTuneRequest(BaseModel):
+    """Request to run supervised fine-tuning of a model."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to create.")
+    training_config: TrainingConfig = Field(..., description="The training configuration.")
+    hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration.")
+    logger_config: dict[str, Any] = Field(..., description="The logger configuration.")
+    model: str | None = Field(
+        default=None,
+        description="Model descriptor for training if not in provider config",
+    )
+    checkpoint_dir: str | None = Field(default=None, description="The directory to save checkpoint(s) to.")
+    algorithm_config: AlgorithmConfig | None = Field(default=None, description="The algorithm configuration.")
+
+
+@json_schema_type
+class PreferenceOptimizeRequest(BaseModel):
+    """Request to run preference optimization of a model."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to create.")
+    finetuned_model: str = Field(..., description="The model to fine-tune.")
+    algorithm_config: DPOAlignmentConfig = Field(..., description="The algorithm configuration.")
+    training_config: TrainingConfig = Field(..., description="The training configuration.")
+    hyperparam_search_config: dict[str, Any] = Field(..., description="The hyperparam search configuration.")
+    logger_config: dict[str, Any] = Field(..., description="The logger configuration.")
+
+
+@json_schema_type
+class GetTrainingJobStatusRequest(BaseModel):
+    """Request to get the status of a training job."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to get the status of.")
+
+
+@json_schema_type
+class CancelTrainingJobRequest(BaseModel):
+    """Request to cancel a training job."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to cancel.")
+
+
+@json_schema_type
+class GetTrainingJobArtifactsRequest(BaseModel):
+    """Request to get the artifacts of a training job."""
+
+    job_uuid: str = Field(..., description="The UUID of the job to get the artifacts of.")
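The request models above group training settings into nested configuration objects, and AlgorithmConfig is a union discriminated on the type field, so a LoRA versus QAT payload is selected by that literal. A sketch of building a supervised fine-tuning request from the models shown in this hunk (the dataset id, model name and LoRA module names are placeholder values):

from llama_stack_api.post_training.models import (
    DataConfig,
    DatasetFormat,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    SupervisedFineTuneRequest,
    TrainingConfig,
)

request = SupervisedFineTuneRequest(
    job_uuid="job-1234",                 # placeholder
    model="example-model",               # optional if the provider config names a model
    training_config=TrainingConfig(
        n_epochs=1,
        data_config=DataConfig(
            dataset_id="example-dataset",  # placeholder
            batch_size=8,
            shuffle=True,
            data_format=DatasetFormat.instruct,
        ),
        optimizer_config=OptimizerConfig(
            optimizer_type=OptimizerType.adamw,
            lr=1e-4,
            weight_decay=0.01,
            num_warmup_steps=100,
        ),
    ),
    hyperparam_search_config={},
    logger_config={},
    # LoraFinetuningConfig defaults type to "LoRA", which selects the LoRA
    # arm of the AlgorithmConfig discriminated union.
    algorithm_config=LoraFinetuningConfig(
        lora_attn_modules=["q_proj", "v_proj"],  # placeholder module names
        apply_lora_to_mlp=False,
        apply_lora_to_output=False,
        rank=8,
        alpha=16,
    ),
)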
llama_stack_api/prompts/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Prompts API protocol and models.
+
+This module contains the Prompts protocol definition.
+Pydantic models are defined in llama_stack_api.prompts.models.
+The FastAPI router is defined in llama_stack_api.prompts.fastapi_routes.
+"""
+
+# Import fastapi_routes for router factory access
+from . import fastapi_routes
+
+# Import protocol for FastAPI router
+from .api import Prompts
+
+# Import models for re-export
+from .models import (
+    CreatePromptRequest,
+    DeletePromptRequest,
+    GetPromptRequest,
+    ListPromptsResponse,
+    ListPromptVersionsRequest,
+    Prompt,
+    SetDefaultVersionBodyRequest,
+    SetDefaultVersionRequest,
+    UpdatePromptBodyRequest,
+    UpdatePromptRequest,
+)
+
+__all__ = [
+    "CreatePromptRequest",
+    "DeletePromptRequest",
+    "GetPromptRequest",
+    "ListPromptVersionsRequest",
+    "ListPromptsResponse",
+    "Prompt",
+    "Prompts",
+    "SetDefaultVersionBodyRequest",
+    "SetDefaultVersionRequest",
+    "UpdatePromptBodyRequest",
+    "UpdatePromptRequest",
+    "fastapi_routes",
+]
llama_stack_api/prompts/api.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Prompts API protocol definition.
+
+This module contains the Prompts protocol definition.
+Pydantic models are defined in llama_stack_api.prompts.models.
+The FastAPI router is defined in llama_stack_api.prompts.fastapi_routes.
+"""
+
+from typing import Protocol, runtime_checkable
+
+from .models import (
+    CreatePromptRequest,
+    DeletePromptRequest,
+    GetPromptRequest,
+    ListPromptsResponse,
+    ListPromptVersionsRequest,
+    Prompt,
+    SetDefaultVersionRequest,
+    UpdatePromptRequest,
+)
+
+
+@runtime_checkable
+class Prompts(Protocol):
+    """Protocol for prompt management operations."""
+
+    async def list_prompts(self) -> ListPromptsResponse: ...
+
+    async def list_prompt_versions(self, request: ListPromptVersionsRequest) -> ListPromptsResponse: ...
+
+    async def get_prompt(self, request: GetPromptRequest) -> Prompt: ...
+
+    async def create_prompt(self, request: CreatePromptRequest) -> Prompt: ...
+
+    async def update_prompt(self, request: UpdatePromptRequest) -> Prompt: ...
+
+    async def delete_prompt(self, request: DeletePromptRequest) -> None: ...
+
+    async def set_default_version(self, request: SetDefaultVersionRequest) -> Prompt: ...