llama-stack-api 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. llama_stack_api/__init__.py +175 -20
  2. llama_stack_api/agents/__init__.py +38 -0
  3. llama_stack_api/agents/api.py +52 -0
  4. llama_stack_api/agents/fastapi_routes.py +268 -0
  5. llama_stack_api/agents/models.py +181 -0
  6. llama_stack_api/common/errors.py +15 -0
  7. llama_stack_api/connectors/__init__.py +38 -0
  8. llama_stack_api/connectors/api.py +50 -0
  9. llama_stack_api/connectors/fastapi_routes.py +103 -0
  10. llama_stack_api/connectors/models.py +103 -0
  11. llama_stack_api/conversations/__init__.py +61 -0
  12. llama_stack_api/conversations/api.py +44 -0
  13. llama_stack_api/conversations/fastapi_routes.py +177 -0
  14. llama_stack_api/conversations/models.py +245 -0
  15. llama_stack_api/datasetio/__init__.py +34 -0
  16. llama_stack_api/datasetio/api.py +42 -0
  17. llama_stack_api/datasetio/fastapi_routes.py +94 -0
  18. llama_stack_api/datasetio/models.py +48 -0
  19. llama_stack_api/eval/__init__.py +55 -0
  20. llama_stack_api/eval/api.py +51 -0
  21. llama_stack_api/eval/compat.py +300 -0
  22. llama_stack_api/eval/fastapi_routes.py +126 -0
  23. llama_stack_api/eval/models.py +141 -0
  24. llama_stack_api/inference/__init__.py +207 -0
  25. llama_stack_api/inference/api.py +93 -0
  26. llama_stack_api/inference/fastapi_routes.py +243 -0
  27. llama_stack_api/inference/models.py +1035 -0
  28. llama_stack_api/models/__init__.py +47 -0
  29. llama_stack_api/models/api.py +38 -0
  30. llama_stack_api/models/fastapi_routes.py +104 -0
  31. llama_stack_api/{models.py → models/models.py} +65 -79
  32. llama_stack_api/openai_responses.py +32 -6
  33. llama_stack_api/post_training/__init__.py +73 -0
  34. llama_stack_api/post_training/api.py +36 -0
  35. llama_stack_api/post_training/fastapi_routes.py +116 -0
  36. llama_stack_api/{post_training.py → post_training/models.py} +55 -86
  37. llama_stack_api/prompts/__init__.py +47 -0
  38. llama_stack_api/prompts/api.py +44 -0
  39. llama_stack_api/prompts/fastapi_routes.py +163 -0
  40. llama_stack_api/prompts/models.py +177 -0
  41. llama_stack_api/resource.py +0 -1
  42. llama_stack_api/safety/__init__.py +37 -0
  43. llama_stack_api/safety/api.py +29 -0
  44. llama_stack_api/safety/datatypes.py +83 -0
  45. llama_stack_api/safety/fastapi_routes.py +55 -0
  46. llama_stack_api/safety/models.py +38 -0
  47. llama_stack_api/schema_utils.py +47 -4
  48. llama_stack_api/scoring/__init__.py +66 -0
  49. llama_stack_api/scoring/api.py +35 -0
  50. llama_stack_api/scoring/fastapi_routes.py +67 -0
  51. llama_stack_api/scoring/models.py +81 -0
  52. llama_stack_api/scoring_functions/__init__.py +50 -0
  53. llama_stack_api/scoring_functions/api.py +39 -0
  54. llama_stack_api/scoring_functions/fastapi_routes.py +108 -0
  55. llama_stack_api/{scoring_functions.py → scoring_functions/models.py} +67 -64
  56. llama_stack_api/shields/__init__.py +41 -0
  57. llama_stack_api/shields/api.py +39 -0
  58. llama_stack_api/shields/fastapi_routes.py +104 -0
  59. llama_stack_api/shields/models.py +74 -0
  60. llama_stack_api/validators.py +46 -0
  61. llama_stack_api/vector_io/__init__.py +88 -0
  62. llama_stack_api/vector_io/api.py +234 -0
  63. llama_stack_api/vector_io/fastapi_routes.py +447 -0
  64. llama_stack_api/{vector_io.py → vector_io/models.py} +99 -377
  65. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/METADATA +1 -1
  66. llama_stack_api-0.5.0rc1.dist-info/RECORD +115 -0
  67. llama_stack_api/agents.py +0 -173
  68. llama_stack_api/connectors.py +0 -146
  69. llama_stack_api/conversations.py +0 -270
  70. llama_stack_api/datasetio.py +0 -55
  71. llama_stack_api/eval.py +0 -137
  72. llama_stack_api/inference.py +0 -1169
  73. llama_stack_api/prompts.py +0 -203
  74. llama_stack_api/safety.py +0 -132
  75. llama_stack_api/scoring.py +0 -93
  76. llama_stack_api/shields.py +0 -93
  77. llama_stack_api-0.4.4.dist-info/RECORD +0 -70
  78. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/WHEEL +0 -0
  79. {llama_stack_api-0.4.4.dist-info → llama_stack_api-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,47 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Models API protocol and models.
8
+
9
+ This module contains the Models protocol definition.
10
+ Pydantic models are defined in llama_stack_api.models.models.
11
+ The FastAPI router is defined in llama_stack_api.models.fastapi_routes.
12
+ """
13
+
14
+ # Import fastapi_routes for router factory access
15
+ from . import fastapi_routes
16
+
17
+ # Import new protocol for FastAPI router
18
+ from .api import Models
19
+
20
+ # Import models for re-export
21
+ from .models import (
22
+ CommonModelFields,
23
+ GetModelRequest,
24
+ ListModelsResponse,
25
+ Model,
26
+ ModelInput,
27
+ ModelType,
28
+ OpenAIListModelsResponse,
29
+ OpenAIModel,
30
+ RegisterModelRequest,
31
+ UnregisterModelRequest,
32
+ )
33
+
34
+ __all__ = [
35
+ "CommonModelFields",
36
+ "fastapi_routes",
37
+ "GetModelRequest",
38
+ "ListModelsResponse",
39
+ "Model",
40
+ "ModelInput",
41
+ "Models",
42
+ "ModelType",
43
+ "OpenAIListModelsResponse",
44
+ "OpenAIModel",
45
+ "RegisterModelRequest",
46
+ "UnregisterModelRequest",
47
+ ]
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Models API protocol definition.
8
+
9
+ This module contains the Models protocol definition.
10
+ Pydantic models are defined in llama_stack_api.models.models.
11
+ The FastAPI router is defined in llama_stack_api.models.fastapi_routes.
12
+ """
13
+
14
+ from typing import Protocol, runtime_checkable
15
+
16
+ from .models import (
17
+ GetModelRequest,
18
+ ListModelsResponse,
19
+ Model,
20
+ OpenAIListModelsResponse,
21
+ RegisterModelRequest,
22
+ UnregisterModelRequest,
23
+ )
24
+
25
+
26
+ @runtime_checkable
27
+ class Models(Protocol):
28
+ """Protocol for model management operations."""
29
+
30
+ async def list_models(self) -> ListModelsResponse: ...
31
+
32
+ async def openai_list_models(self) -> OpenAIListModelsResponse: ...
33
+
34
+ async def get_model(self, request: GetModelRequest) -> Model: ...
35
+
36
+ async def register_model(self, request: RegisterModelRequest) -> Model: ...
37
+
38
+ async def unregister_model(self, request: UnregisterModelRequest) -> None: ...
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """FastAPI router for the Models API.
8
+
9
+ This module defines the FastAPI router for the Models API using standard
10
+ FastAPI route decorators.
11
+ """
12
+
13
+ from typing import Annotated
14
+
15
+ from fastapi import APIRouter, Body, Depends
16
+
17
+ from llama_stack_api.router_utils import create_path_dependency, standard_responses
18
+ from llama_stack_api.version import LLAMA_STACK_API_V1
19
+
20
+ from .api import Models
21
+ from .models import (
22
+ GetModelRequest,
23
+ Model,
24
+ OpenAIListModelsResponse,
25
+ RegisterModelRequest,
26
+ UnregisterModelRequest,
27
+ )
28
+
29
+ # Path parameter dependencies for single-field models
30
+ get_model_request = create_path_dependency(GetModelRequest)
31
+ unregister_model_request = create_path_dependency(UnregisterModelRequest)
32
+
33
+
34
+ def create_router(impl: Models) -> APIRouter:
35
+ """Create a FastAPI router for the Models API.
36
+
37
+ Args:
38
+ impl: The Models implementation instance
39
+
40
+ Returns:
41
+ APIRouter configured for the Models API
42
+ """
43
+ router = APIRouter(
44
+ prefix=f"/{LLAMA_STACK_API_V1}",
45
+ tags=["Models"],
46
+ responses=standard_responses,
47
+ )
48
+
49
+ @router.get(
50
+ "/models",
51
+ response_model=OpenAIListModelsResponse,
52
+ summary="List models using the OpenAI API.",
53
+ description="List models using the OpenAI API.",
54
+ responses={
55
+ 200: {"description": "A list of OpenAI model objects."},
56
+ },
57
+ )
58
+ async def openai_list_models() -> OpenAIListModelsResponse:
59
+ return await impl.openai_list_models()
60
+
61
+ @router.get(
62
+ "/models/{model_id:path}",
63
+ response_model=Model,
64
+ summary="Get a model by its identifier.",
65
+ description="Get a model by its identifier.",
66
+ responses={
67
+ 200: {"description": "The model object."},
68
+ },
69
+ )
70
+ async def get_model(
71
+ request: Annotated[GetModelRequest, Depends(get_model_request)],
72
+ ) -> Model:
73
+ return await impl.get_model(request)
74
+
75
+ @router.post(
76
+ "/models",
77
+ response_model=Model,
78
+ summary="Register a model.",
79
+ description="Register a model.",
80
+ responses={
81
+ 200: {"description": "The registered model object."},
82
+ },
83
+ deprecated=True,
84
+ )
85
+ async def register_model(
86
+ request: Annotated[RegisterModelRequest, Body(...)],
87
+ ) -> Model:
88
+ return await impl.register_model(request)
89
+
90
+ @router.delete(
91
+ "/models/{model_id:path}",
92
+ summary="Unregister a model.",
93
+ description="Unregister a model.",
94
+ responses={
95
+ 200: {"description": "The model was successfully unregistered."},
96
+ },
97
+ deprecated=True,
98
+ )
99
+ async def unregister_model(
100
+ request: Annotated[UnregisterModelRequest, Depends(unregister_model_request)],
101
+ ) -> None:
102
+ return await impl.unregister_model(request)
103
+
104
+ return router
@@ -4,26 +4,25 @@
4
4
  # This source code is licensed under the terms described in the LICENSE file in
5
5
  # the root directory of this source tree.
6
6
 
7
+ """Pydantic models for Models API requests and responses.
8
+
9
+ This module defines the request and response models for the Models API
10
+ using Pydantic with Field descriptions for OpenAPI schema generation.
11
+ """
12
+
7
13
  from enum import StrEnum
8
- from typing import Any, Literal, Protocol, runtime_checkable
14
+ from typing import Any, Literal
9
15
 
10
16
  from pydantic import BaseModel, ConfigDict, Field, field_validator
11
17
 
12
18
  from llama_stack_api.resource import Resource, ResourceType
13
- from llama_stack_api.schema_utils import json_schema_type, webmethod
14
- from llama_stack_api.version import LLAMA_STACK_API_V1
15
-
16
-
17
- class CommonModelFields(BaseModel):
18
- metadata: dict[str, Any] = Field(
19
- default_factory=dict,
20
- description="Any additional metadata for this model",
21
- )
19
+ from llama_stack_api.schema_utils import json_schema_type
22
20
 
23
21
 
24
22
  @json_schema_type
25
23
  class ModelType(StrEnum):
26
24
  """Enumeration of supported model types in Llama Stack.
25
+
27
26
  :cvar llm: Large language model for text generation and completion
28
27
  :cvar embedding: Embedding model for converting text to vector representations
29
28
  :cvar rerank: Reranking model for reordering documents based on their relevance to a query
@@ -34,6 +33,13 @@ class ModelType(StrEnum):
34
33
  rerank = "rerank"
35
34
 
36
35
 
36
+ class CommonModelFields(BaseModel):
37
+ metadata: dict[str, Any] = Field(
38
+ default_factory=dict,
39
+ description="Any additional metadata for this model",
40
+ )
41
+
42
+
37
43
  @json_schema_type
38
44
  class Model(CommonModelFields, Resource):
39
45
  """A model resource representing an AI model registered in Llama Stack.
@@ -77,8 +83,11 @@ class ModelInput(CommonModelFields):
77
83
  model_config = ConfigDict(protected_namespaces=())
78
84
 
79
85
 
86
+ @json_schema_type
80
87
  class ListModelsResponse(BaseModel):
81
- data: list[Model]
88
+ """Response containing a list of model objects."""
89
+
90
+ data: list[Model] = Field(..., description="List of model objects.")
82
91
 
83
92
 
84
93
  @json_schema_type
@@ -101,71 +110,48 @@ class OpenAIModel(BaseModel):
101
110
 
102
111
  @json_schema_type
103
112
  class OpenAIListModelsResponse(BaseModel):
104
- data: list[OpenAIModel]
105
-
106
-
107
- @runtime_checkable
108
- class Models(Protocol):
109
- async def list_models(self) -> ListModelsResponse:
110
- """List all models.
111
-
112
- :returns: A ListModelsResponse.
113
- """
114
- ...
115
-
116
- @webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
117
- async def openai_list_models(self) -> OpenAIListModelsResponse:
118
- """List models using the OpenAI API.
119
-
120
- :returns: A OpenAIListModelsResponse.
121
- """
122
- ...
123
-
124
- @webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
125
- async def get_model(
126
- self,
127
- model_id: str,
128
- ) -> Model:
129
- """Get model.
130
-
131
- Get a model by its identifier.
132
-
133
- :param model_id: The identifier of the model to get.
134
- :returns: A Model.
135
- """
136
- ...
137
-
138
- @webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
139
- async def register_model(
140
- self,
141
- model_id: str,
142
- provider_model_id: str | None = None,
143
- provider_id: str | None = None,
144
- metadata: dict[str, Any] | None = None,
145
- model_type: ModelType | None = None,
146
- ) -> Model:
147
- """Register model.
148
-
149
- Register a model.
150
-
151
- :param model_id: The identifier of the model to register.
152
- :param provider_model_id: The identifier of the model in the provider.
153
- :param provider_id: The identifier of the provider.
154
- :param metadata: Any additional metadata for this model.
155
- :param model_type: The type of model to register.
156
- :returns: A Model.
157
- """
158
- ...
159
-
160
- @webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
161
- async def unregister_model(
162
- self,
163
- model_id: str,
164
- ) -> None:
165
- """Unregister model.
166
-
167
- Unregister a model.
168
-
169
- :param model_id: The identifier of the model to unregister.
170
- """
171
- ...
113
+ """Response containing a list of OpenAI model objects."""
114
+
115
+ data: list[OpenAIModel] = Field(..., description="List of OpenAI model objects.")
116
+
117
+
118
+ # Request models for each endpoint
119
+
120
+
121
+ @json_schema_type
122
+ class GetModelRequest(BaseModel):
123
+ """Request model for getting a model by ID."""
124
+
125
+ model_id: str = Field(..., description="The ID of the model to get.")
126
+
127
+
128
+ @json_schema_type
129
+ class RegisterModelRequest(BaseModel):
130
+ """Request model for registering a model."""
131
+
132
+ model_id: str = Field(..., description="The identifier of the model to register.")
133
+ provider_model_id: str | None = Field(default=None, description="The identifier of the model in the provider.")
134
+ provider_id: str | None = Field(default=None, description="The identifier of the provider.")
135
+ metadata: dict[str, Any] | None = Field(default=None, description="Any additional metadata for this model.")
136
+ model_type: ModelType | None = Field(default=None, description="The type of model to register.")
137
+
138
+
139
+ @json_schema_type
140
+ class UnregisterModelRequest(BaseModel):
141
+ """Request model for unregistering a model."""
142
+
143
+ model_id: str = Field(..., description="The ID of the model to unregister.")
144
+
145
+
146
+ __all__ = [
147
+ "CommonModelFields",
148
+ "GetModelRequest",
149
+ "ListModelsResponse",
150
+ "Model",
151
+ "ModelInput",
152
+ "ModelType",
153
+ "OpenAIListModelsResponse",
154
+ "OpenAIModel",
155
+ "RegisterModelRequest",
156
+ "UnregisterModelRequest",
157
+ ]
@@ -405,6 +405,19 @@ class OpenAIResponseText(BaseModel):
405
405
  format: OpenAIResponseTextFormat | None = None
406
406
 
407
407
 
408
+ @json_schema_type
409
+ class OpenAIResponseReasoning(BaseModel):
410
+ """Configuration for reasoning effort in OpenAI responses.
411
+
412
+ Controls how much reasoning the model performs before generating a response.
413
+
414
+ :param effort: The effort level for reasoning. "low" favors speed and economical token usage,
415
+ "high" favors more complete reasoning, "medium" is a balance between the two.
416
+ """
417
+
418
+ effort: Literal["none", "minimal", "low", "medium", "high", "xhigh"] | None = None
419
+
420
+
408
421
  # Must match type Literals of OpenAIResponseInputToolWebSearch below
409
422
  WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_2025_03_11", "web_search_2025_08_26"]
410
423
 
@@ -491,7 +504,8 @@ class OpenAIResponseInputToolMCP(BaseModel):
491
504
 
492
505
  :param type: Tool type identifier, always "mcp"
493
506
  :param server_label: Label to identify this MCP server
494
- :param server_url: URL endpoint of the MCP server
507
+ :param connector_id: (Optional) ID of the connector to use for this MCP server
508
+ :param server_url: (Optional) URL endpoint of the MCP server
495
509
  :param headers: (Optional) HTTP headers to include when connecting to the server
496
510
  :param authorization: (Optional) OAuth access token for authenticating with the MCP server
497
511
  :param require_approval: Approval requirement for tool calls ("always", "never", or filter)
@@ -500,13 +514,20 @@ class OpenAIResponseInputToolMCP(BaseModel):
500
514
 
501
515
  type: Literal["mcp"] = "mcp"
502
516
  server_label: str
503
- server_url: str
517
+ connector_id: str | None = None
518
+ server_url: str | None = None
504
519
  headers: dict[str, Any] | None = None
505
520
  authorization: str | None = Field(default=None, exclude=True)
506
521
 
507
522
  require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
508
523
  allowed_tools: list[str] | AllowedToolsFilter | None = None
509
524
 
525
+ @model_validator(mode="after")
526
+ def validate_server_or_connector(self) -> "OpenAIResponseInputToolMCP":
527
+ if not self.server_url and not self.connector_id:
528
+ raise ValueError("Either 'server_url' or 'connector_id' must be provided for MCP tool")
529
+ return self
530
+
510
531
 
511
532
  OpenAIResponseInputTool = Annotated[
512
533
  OpenAIResponseInputToolWebSearch
@@ -647,7 +668,7 @@ class OpenAIResponseUsageOutputTokensDetails(BaseModel):
647
668
  :param reasoning_tokens: Number of tokens used for reasoning (o1/o3 models)
648
669
  """
649
670
 
650
- reasoning_tokens: int | None = None
671
+ reasoning_tokens: int
651
672
 
652
673
 
653
674
  class OpenAIResponseUsageInputTokensDetails(BaseModel):
@@ -656,7 +677,7 @@ class OpenAIResponseUsageInputTokensDetails(BaseModel):
656
677
  :param cached_tokens: Number of tokens retrieved from cache
657
678
  """
658
679
 
659
- cached_tokens: int | None = None
680
+ cached_tokens: int
660
681
 
661
682
 
662
683
  @json_schema_type
@@ -673,8 +694,8 @@ class OpenAIResponseUsage(BaseModel):
673
694
  input_tokens: int
674
695
  output_tokens: int
675
696
  total_tokens: int
676
- input_tokens_details: OpenAIResponseUsageInputTokensDetails | None = None
677
- output_tokens_details: OpenAIResponseUsageOutputTokensDetails | None = None
697
+ input_tokens_details: OpenAIResponseUsageInputTokensDetails
698
+ output_tokens_details: OpenAIResponseUsageOutputTokensDetails
678
699
 
679
700
 
680
701
  @json_schema_type
@@ -700,10 +721,12 @@ class OpenAIResponseObject(BaseModel):
700
721
  :param usage: (Optional) Token usage information for the response
701
722
  :param instructions: (Optional) System message inserted into the model's context
702
723
  :param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
724
+ :param max_output_tokens: (Optional) An upper bound for the number of tokens that can be generated for a response, including visible output tokens.
703
725
  :param metadata: (Optional) Dictionary of metadata key-value pairs
704
726
  """
705
727
 
706
728
  created_at: int
729
+ completed_at: int | None = None
707
730
  error: OpenAIResponseError | None = None
708
731
  id: str
709
732
  model: str
@@ -724,7 +747,10 @@ class OpenAIResponseObject(BaseModel):
724
747
  usage: OpenAIResponseUsage | None = None
725
748
  instructions: str | None = None
726
749
  max_tool_calls: int | None = None
750
+ reasoning: OpenAIResponseReasoning | None = None
751
+ max_output_tokens: int | None = None
727
752
  metadata: dict[str, str] | None = None
753
+ store: bool
728
754
 
729
755
 
730
756
  @json_schema_type
@@ -0,0 +1,73 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ """Post-Training API protocol and models.
8
+
9
+ This module contains the Post-Training protocol definition.
10
+ Pydantic models are defined in llama_stack_api.post_training.models.
11
+ The FastAPI router is defined in llama_stack_api.post_training.fastapi_routes.
12
+ """
13
+
14
+ # Import fastapi_routes for router factory access
15
+ from . import fastapi_routes
16
+
17
+ # Import protocol for re-export
18
+ from .api import PostTraining
19
+
20
+ # Import models for re-export
21
+ from .models import (
22
+ AlgorithmConfig,
23
+ CancelTrainingJobRequest,
24
+ DataConfig,
25
+ DatasetFormat,
26
+ DPOAlignmentConfig,
27
+ DPOLossType,
28
+ EfficiencyConfig,
29
+ GetTrainingJobArtifactsRequest,
30
+ GetTrainingJobStatusRequest,
31
+ ListPostTrainingJobsResponse,
32
+ LoraFinetuningConfig,
33
+ OptimizerConfig,
34
+ OptimizerType,
35
+ PostTrainingJob,
36
+ PostTrainingJobArtifactsResponse,
37
+ PostTrainingJobLogStream,
38
+ PostTrainingJobStatusResponse,
39
+ PostTrainingRLHFRequest,
40
+ PreferenceOptimizeRequest,
41
+ QATFinetuningConfig,
42
+ RLHFAlgorithm,
43
+ SupervisedFineTuneRequest,
44
+ TrainingConfig,
45
+ )
46
+
47
+ __all__ = [
48
+ "PostTraining",
49
+ "AlgorithmConfig",
50
+ "CancelTrainingJobRequest",
51
+ "DataConfig",
52
+ "DatasetFormat",
53
+ "DPOAlignmentConfig",
54
+ "DPOLossType",
55
+ "EfficiencyConfig",
56
+ "GetTrainingJobArtifactsRequest",
57
+ "GetTrainingJobStatusRequest",
58
+ "ListPostTrainingJobsResponse",
59
+ "LoraFinetuningConfig",
60
+ "OptimizerConfig",
61
+ "OptimizerType",
62
+ "PostTrainingJob",
63
+ "PostTrainingJobArtifactsResponse",
64
+ "PostTrainingJobLogStream",
65
+ "PostTrainingJobStatusResponse",
66
+ "PostTrainingRLHFRequest",
67
+ "PreferenceOptimizeRequest",
68
+ "QATFinetuningConfig",
69
+ "RLHFAlgorithm",
70
+ "SupervisedFineTuneRequest",
71
+ "TrainingConfig",
72
+ "fastapi_routes",
73
+ ]
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the terms described in the LICENSE file in
5
+ # the root directory of this source tree.
6
+
7
+ from typing import Protocol, runtime_checkable
8
+
9
+ from .models import (
10
+ CancelTrainingJobRequest,
11
+ GetTrainingJobArtifactsRequest,
12
+ GetTrainingJobStatusRequest,
13
+ ListPostTrainingJobsResponse,
14
+ PostTrainingJob,
15
+ PostTrainingJobArtifactsResponse,
16
+ PostTrainingJobStatusResponse,
17
+ PreferenceOptimizeRequest,
18
+ SupervisedFineTuneRequest,
19
+ )
20
+
21
+
22
+ @runtime_checkable
23
+ class PostTraining(Protocol):
24
+ async def supervised_fine_tune(self, request: SupervisedFineTuneRequest) -> PostTrainingJob: ...
25
+
26
+ async def preference_optimize(self, request: PreferenceOptimizeRequest) -> PostTrainingJob: ...
27
+
28
+ async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ...
29
+
30
+ async def get_training_job_status(self, request: GetTrainingJobStatusRequest) -> PostTrainingJobStatusResponse: ...
31
+
32
+ async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None: ...
33
+
34
+ async def get_training_job_artifacts(
35
+ self, request: GetTrainingJobArtifactsRequest
36
+ ) -> PostTrainingJobArtifactsResponse: ...