seekrai 0.5.26__tar.gz → 0.5.29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {seekrai-0.5.26 → seekrai-0.5.29}/PKG-INFO +1 -1
- {seekrai-0.5.26 → seekrai-0.5.29}/pyproject.toml +1 -1
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/__init__.py +11 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/__init__.py +2 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/__init__.py +2 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/__init__.py +2 -0
- seekrai-0.5.29/src/seekrai/types/agents/tools/schemas/agent_as_tool.py +16 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/tool_types.py +5 -1
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/enums.py +1 -0
- seekrai-0.5.29/src/seekrai/types/finetune.py +459 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/tools.py +22 -3
- seekrai-0.5.26/src/seekrai/types/finetune.py +0 -256
- {seekrai-0.5.26 → seekrai-0.5.29}/LICENSE +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/README.md +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/__init__.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/abstract/__init__.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/abstract/api_requestor.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/abstract/response_parsing.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/client.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/constants.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/error.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/filemanager.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/__init__.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/__init__.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/agent_inference.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/agent_observability.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/agents.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/python_functions.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/threads.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/alignment.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/chat/__init__.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/chat/completions.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/completions.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/deployments.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/embeddings.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/explainability.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/files.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/finetune.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/images.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/ingestion.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/models.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/projects.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/resource_base.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/tools.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/vectordb.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/seekrflow_response.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/abstract.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/agent.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/observability.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/python_functions.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/runs.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/threads.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/env_model_config.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/file_search.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/file_search_env.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/run_python.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/run_python_env.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/web_search.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/web_search_env.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/tool.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/alignment.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/chat_completions.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/common.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/completions.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/deployments.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/embeddings.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/error.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/explainability.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/files.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/images.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/ingestion.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/models.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/projects.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/vectordb.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/__init__.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/_log.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/api_helpers.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/files.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/tools.py +0 -0
- {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/version.py +0 -0
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/__init__.py

```diff
@@ -1,6 +1,7 @@
 from seekrai.types.abstract import SeekrFlowClient
 from seekrai.types.agents import (
     Agent,
+    AgentAsToolLegacy,
     AgentDeleteResponse,
     CreateAgentRequest,
     EnvConfig,
@@ -47,6 +48,7 @@ from seekrai.types.agents import (
     WebSearchEnv,
 )
 from seekrai.types.agents.tools.schemas import (
+    AgentAsToolLegacy,
     FileSearch,
     FileSearchEnv,
     RunPython,
@@ -120,6 +122,9 @@ from seekrai.types.projects import (
     ProjectWithRuns,
 )
 from seekrai.types.tools import (
+    AgentAsTool,
+    AgentAsToolConfig,
+    CreateAgentAsTool,
     CreateFileSearch,
     CreateRunPython,
     CreateToolRequest,
@@ -132,6 +137,7 @@ from seekrai.types.tools import (
     ToolAgentSummaryResponse,
     ToolDeleteResponse,
     ToolResponse,
+    UpdateAgentAsTool,
     UpdateFileSearch,
     UpdateRunPython,
     UpdateToolRequest,
@@ -258,4 +264,9 @@ __all__ = [
     "RunPythonEnv",
     "WebSearch",
     "WebSearchEnv",
+    "AgentAsToolLegacy",
+    "AgentAsTool",
+    "CreateAgentAsTool",
+    "AgentAsToolConfig",
+    "UpdateAgentAsTool",
 ]
```
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/__init__.py

```diff
@@ -51,6 +51,7 @@ from seekrai.types.agents.tools import (
     ToolBase,
 )
 from seekrai.types.agents.tools.schemas import (
+    AgentAsToolLegacy,
     FileSearch,
     FileSearchEnv,
     RunPython,
@@ -106,6 +107,7 @@ __all__ = [
     "RunPythonEnv",
     "WebSearch",
     "WebSearchEnv",
+    "AgentAsToolLegacy",
    "PythonFunctionBase",
    "PythonFunctionResponse",
    "DeletePythonFunctionResponse",
```
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/__init__.py

```diff
@@ -1,5 +1,6 @@
 from seekrai.types.agents.tools.env_model_config import EnvConfig
 from seekrai.types.agents.tools.schemas import (
+    AgentAsToolLegacy,
     FileSearch,
     FileSearchEnv,
     RunPython,
@@ -23,4 +24,5 @@ __all__ = [
     "RunPythonEnv",
     "WebSearch",
     "WebSearchEnv",
+    "AgentAsToolLegacy",
 ]
```
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/__init__.py

```diff
@@ -1,3 +1,4 @@
+from seekrai.types.agents.tools.schemas.agent_as_tool import AgentAsToolLegacy
 from seekrai.types.agents.tools.schemas.file_search import FileSearch
 from seekrai.types.agents.tools.schemas.file_search_env import FileSearchEnv
 from seekrai.types.agents.tools.schemas.run_python import RunPython
@@ -13,4 +14,5 @@ __all__ = [
     "RunPythonEnv",
     "WebSearch",
     "WebSearchEnv",
+    "AgentAsToolLegacy",
 ]
```
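Taken together, the `__init__.py` changes above re-export the new class at every level of the package, so any of these imports resolves to the same `AgentAsToolLegacy`; a quick sketch:

```python
from seekrai.types import AgentAsToolLegacy
from seekrai.types.agents import AgentAsToolLegacy  # noqa: F811 — same class
from seekrai.types.agents.tools import AgentAsToolLegacy  # noqa: F811
from seekrai.types.agents.tools.schemas import AgentAsToolLegacy  # noqa: F811
```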
seekrai-0.5.29/src/seekrai/types/agents/tools/schemas/agent_as_tool.py (new file)

```diff
@@ -0,0 +1,16 @@
+from typing import Literal
+
+from seekrai.types.agents.tools import EnvConfig
+from seekrai.types.agents.tools.tool import ToolBase, ToolType
+
+
+class AgentAsToolLegacy(ToolBase[Literal["agent_as_tool"], EnvConfig]):
+    name: Literal["agent_as_tool"] = ToolType.AGENT_AS_TOOL.value
+    description: str
+    agent_id: str
+
+    model_config = {
+        "json_schema_extra": {
+            "deprecated": True,
+        }
+    }
```
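For orientation, a minimal usage sketch of the new schema. It assumes `ToolBase` (not visible in this diff) adds no further required fields beyond those declared above, and the agent ID is hypothetical:

```python
from seekrai.types.agents.tools.schemas import AgentAsToolLegacy

# Assumes ToolBase requires no extra fields beyond those declared above.
tool = AgentAsToolLegacy(
    description="Delegates requests to a helper agent.",
    agent_id="agent-123",  # hypothetical agent ID
)
assert tool.name == "agent_as_tool"  # fixed discriminator value

# The json_schema_extra entry marks the generated JSON schema as deprecated:
assert AgentAsToolLegacy.model_json_schema().get("deprecated") is True
```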
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/tool_types.py

```diff
@@ -2,9 +2,13 @@ from typing import Annotated, Union
 
 from pydantic import Field
 
+from seekrai.types.agents.tools.schemas.agent_as_tool import AgentAsToolLegacy
 from seekrai.types.agents.tools.schemas.file_search import FileSearch
 from seekrai.types.agents.tools.schemas.run_python import RunPython
 from seekrai.types.agents.tools.schemas.web_search import WebSearch
 
 
-Tool = Annotated[
+Tool = Annotated[
+    Union[FileSearch, WebSearch, RunPython, AgentAsToolLegacy],
+    Field(discriminator="name"),
+]
```
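Because the union discriminates on `name`, a raw payload is routed to the matching schema automatically. A sketch using pydantic's `TypeAdapter`, with the same caveat about unseen `ToolBase` fields and hypothetical payload values:

```python
from pydantic import TypeAdapter

from seekrai.types.agents.tools.tool_types import Tool

adapter = TypeAdapter(Tool)
parsed = adapter.validate_python(
    {
        "name": "agent_as_tool",  # discriminator selects AgentAsToolLegacy
        "description": "Delegates requests to a helper agent.",
        "agent_id": "agent-123",  # hypothetical agent ID
    }
)
assert type(parsed).__name__ == "AgentAsToolLegacy"
```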
seekrai-0.5.29/src/seekrai/types/finetune.py (new file)

```diff
@@ -0,0 +1,459 @@
+from __future__ import annotations
+
+import warnings
+from datetime import datetime
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional
+
+from pydantic import Field, field_validator, model_serializer, model_validator
+
+from seekrai.types.abstract import BaseModel
+from seekrai.types.common import (
+    ObjectType,
+)
+from seekrai.utils._log import log_info
+
+
+class FinetuneJobStatus(str, Enum):
+    """
+    Possible fine-tune job status
+    """
+
+    STATUS_PENDING = "pending"
+    STATUS_QUEUED = "queued"
+    STATUS_RUNNING = "running"
+    # STATUS_COMPRESSING = "compressing"
+    # STATUS_UPLOADING = "uploading"
+    STATUS_CANCEL_REQUESTED = "cancel_requested"
+    STATUS_CANCELLED = "cancelled"
+    STATUS_FAILED = "failed"
+    STATUS_COMPLETED = "completed"
+    STATUS_DELETED = "deleted"
+
+
+class FinetuneEventLevels(str, Enum):
+    """
+    Fine-tune job event status levels
+    """
+
+    NULL = ""
+    INFO = "Info"
+    WARNING = "Warning"
+    ERROR = "Error"
+    LEGACY_INFO = "info"
+    LEGACY_IWARNING = "warning"
+    LEGACY_IERROR = "error"
+
+
+class FinetuneEventType(str, Enum):
+    """
+    Fine-tune job event types
+    """
+
+    JOB_PENDING = "JOB_PENDING"
+    JOB_START = "JOB_START"
+    JOB_STOPPED = "JOB_STOPPED"
+    MODEL_DOWNLOADING = "MODEL_DOWNLOADING"
+    MODEL_DOWNLOAD_COMPLETE = "MODEL_DOWNLOAD_COMPLETE"
+    TRAINING_DATA_DOWNLOADING = "TRAINING_DATA_DOWNLOADING"
+    TRAINING_DATA_DOWNLOAD_COMPLETE = "TRAINING_DATA_DOWNLOAD_COMPLETE"
+    VALIDATION_DATA_DOWNLOADING = "VALIDATION_DATA_DOWNLOADING"
+    VALIDATION_DATA_DOWNLOAD_COMPLETE = "VALIDATION_DATA_DOWNLOAD_COMPLETE"
+    WANDB_INIT = "WANDB_INIT"
+    TRAINING_START = "TRAINING_START"
+    CHECKPOINT_SAVE = "CHECKPOINT_SAVE"
+    BILLING_LIMIT = "BILLING_LIMIT"
+    EPOCH_COMPLETE = "EPOCH_COMPLETE"
+    TRAINING_COMPLETE = "TRAINING_COMPLETE"
+    MODEL_COMPRESSING = "COMPRESSING_MODEL"
+    MODEL_COMPRESSION_COMPLETE = "MODEL_COMPRESSION_COMPLETE"
+    MODEL_UPLOADING = "MODEL_UPLOADING"
+    MODEL_UPLOAD_COMPLETE = "MODEL_UPLOAD_COMPLETE"
+    JOB_COMPLETE = "JOB_COMPLETE"
+    JOB_ERROR = "JOB_ERROR"
+    CANCEL_REQUESTED = "CANCEL_REQUESTED"
+    JOB_RESTARTED = "JOB_RESTARTED"
+    REFUND = "REFUND"
+    WARNING = "WARNING"
+
+
+class FineTuneType(str, Enum):
+    STANDARD = "STANDARD"
+    GRPO = "GRPO"  # deprecated
+    PREFERENCE = "PREFERENCE"
+    REINFORCEMENT = "REINFORCEMENT"
+
+
+class GraderType(str, Enum):
+    FORMAT_CHECK = "format_check"
+    MATH_ACCURACY = "math_accuracy"
+    STRING_CHECK = "string_check"
+    TEXT_SIMILARITY = "text_similarity"
+
+
+class StringOperation(str, Enum):
+    EQUALS = "equals"
+    NOT_EQUALS = "not_equals"
+    CONTAINS = "contains"
+    CASE_INSENSITIVE_CONTAINS = "case_insensitive_contains"
+
+
+class TextSimilarityOperation(str, Enum):
+    BLEU = "bleu"
+    ROUGE = "rouge"
+
+
+class FinetuneEvent(BaseModel):
+    """
+    Fine-tune event type
+    """
+
+    # object type
+    object: Literal[ObjectType.FinetuneEvent]
+    # created at datetime stamp
+    created_at: datetime | None = None
+    # metrics that we expose
+    loss: float | None = None
+    epoch: float | None = None
+    reward: float | None = None
+
+    @model_serializer(mode="wrap")
+    def serialize_model(
+        self, handler: Callable[[Any], dict[str, Any]]
+    ) -> dict[str, Any]:
+        # Remove 'reward' if it's None
+        dump_dict = handler(self)
+        if dump_dict.get("reward") is None:
+            del dump_dict["reward"]
+        return dump_dict
+
+
+class LoRAConfig(BaseModel):
+    r: int = Field(8, gt=0, description="Rank of the update matrices.")
+    alpha: int = Field(32, gt=0, description="Scaling factor applied to LoRA updates.")
+    dropout: float = Field(
+        0.1,
+        ge=0.0,
+        le=1.0,
+        description="Fraction of LoRA neurons dropped during training.",
+    )
+    bias: Literal["none", "all", "lora_only"] = Field(
+        "none",
+        description="Bias terms to train; choose from 'none', 'all', or 'lora_only'.",
+    )
+    extras: Dict[str, Any] = Field(default_factory=dict)
+
+
+class Grader(BaseModel):
+    type: GraderType
+    weight: float | None = Field(default=None, gt=0.0, le=1.0)
+    operation: StringOperation | TextSimilarityOperation | None = Field(default=None)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_operation(cls, data: Any) -> Any:
+        if not isinstance(data, dict):
+            return data
+
+        grader_type = data.get("type")
+        operation_value = data.get("operation")
+
+        if grader_type == GraderType.STRING_CHECK:
+            if not operation_value:
+                raise ValueError(
+                    "string_check grader is missing required StringOperation"
+                )
+            if isinstance(operation_value, str):
+                try:
+                    # Convert to enum to validate it's a valid value
+                    data["operation"] = StringOperation(operation_value.lower())
+                except ValueError:
+                    raise ValueError(
+                        f"Invalid operation for string_check grader: "
+                        f"expected StringOperation, but got type '{type(operation_value).__name__}' with value '{operation_value}'"
+                    )
+        elif grader_type == GraderType.TEXT_SIMILARITY:
+            if not operation_value:
+                raise ValueError(
+                    "text_similarity grader is missing required TextSimilarityOperation"
+                )
+            if isinstance(operation_value, str):
+                try:
+                    data["operation"] = TextSimilarityOperation(operation_value.lower())
+                except ValueError:
+                    raise ValueError(
+                        f"Invalid operation for text_similarity grader: "
+                        f"expected TextSimilarityOperation, got type '{type(operation_value).__name__}' with value '{operation_value}'"
+                    )
+
+        elif grader_type in (GraderType.FORMAT_CHECK, GraderType.MATH_ACCURACY):
+            if operation_value:
+                raise ValueError(f"{grader_type} grader cannot have an operation")
+            data["operation"] = None
+
+        return data
+
+
+class RewardComponents(BaseModel):
+    format_reward_weight: float = Field(default=0.1, gt=0.0, le=1.0)
+    graders: list[Grader] = Field(min_length=1)
+
+    @model_validator(mode="after")
+    def validate_weights(self) -> "RewardComponents":
+        is_format_weight_specified = "format_reward_weight" in self.model_fields_set
+
+        grader_weights_specified = [
+            grader.weight is not None for grader in self.graders
+        ]
+
+        all_graders_have_weights = all(grader_weights_specified)
+        some_graders_have_weights = any(grader_weights_specified) and not all(
+            grader_weights_specified
+        )
+        no_graders_have_weights = not any(grader_weights_specified)
+
+        if some_graders_have_weights:
+            raise ValueError(
+                "Only some graders have weights specified. Either all graders must have weights specified, or none of them."
+            )
+
+        if all_graders_have_weights and is_format_weight_specified:
+            self._validate_weights_sum_to_one()
+
+        elif all_graders_have_weights and not is_format_weight_specified:
+            self._normalize_grader_weights()
+
+        elif no_graders_have_weights:
+            self._initialize_grader_weights()
+            self._normalize_grader_weights()
+
+        return self
+
+    def _validate_weights_sum_to_one(self) -> None:
+        """Validate that format_reward_weight and grader weights sum to 1.0"""
+        total_weight = self.format_reward_weight + sum(  # type: ignore[operator]
+            grader.weight  # type: ignore[misc]
+            for grader in self.graders
+        )
+
+        if abs(total_weight - 1.0) > 1e-10:
+            raise ValueError(
+                f"When all weights are explicitly provided, they must sum to 1.0. "
+                f"Got format_reward_weight={self.format_reward_weight}, "
+                f"graders={self.graders}"
+            )
+
+    def _normalize_grader_weights(self) -> None:
+        """Normalize only grader weights to fill (1 - format_reward_weight)"""
+        total_grader_weight = sum(grader.weight for grader in self.graders)  # type: ignore[misc]
+        target_grader_total = 1.0 - self.format_reward_weight
+
+        # only normalize if weights aren't already properly normalized
+        if abs(total_grader_weight - target_grader_total) > 1e-10:
+            scale_factor = target_grader_total / total_grader_weight
+            for grader in self.graders:
+                original_weight = grader.weight
+                grader.weight *= scale_factor  # type: ignore[operator]
+                log_info(
+                    f"{grader.type}'s weight scaled from {original_weight} to {grader.weight:.2f}"
+                )
+
+    def _initialize_grader_weights(self) -> None:
+        """Initialize all grader weights when none are provided"""
+        for grader in self.graders:
+            grader.weight = 1.0
+
+
+class TrainingConfig(BaseModel):
+    # training file ID
+    training_files: List[str]
+    # base model string
+    model: str
+    # number of epochs to train for
+    n_epochs: int
+    # training learning rate
+    learning_rate: float
+    # number of checkpoints to save
+    n_checkpoints: int | None = None
+    # training batch size
+    batch_size: int = Field(..., ge=1, le=1024)
+    # up to 40 character suffix for output model name
+    experiment_name: str | None = None
+    # sequence length
+    max_length: int = 2500
+    # # weights & biases api key
+    # wandb_key: str | None = None
+    # IFT by default
+    pre_train: bool = False
+    # fine-tune type
+    fine_tune_type: FineTuneType = FineTuneType.STANDARD
+    # LoRA config
+    lora_config: Optional[LoRAConfig] = None
+    # reward_components are REINFORCEMENT-specific
+    reward_components: Optional[RewardComponents] = None
+
+    @model_validator(mode="after")
+    def validate_reward_components(self) -> "TrainingConfig":
+        # TODO: re-enable the below and make reward_components required for REINFORCEMENT. Disabled for now for backwards-compatibility
+        # if (
+        #     self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
+        #     and not self.reward_components
+        # ):
+        #     raise ValueError("REINFORCEMENT fine-tuning requires reward components")
+        if (
+            self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
+            and not self.reward_components
+        ):
+            self.reward_components = RewardComponents(
+                format_reward_weight=0.1,
+                graders=[Grader(type=GraderType.MATH_ACCURACY, weight=0.9)],
+            )
+        if self.fine_tune_type == FineTuneType.STANDARD and self.reward_components:
+            raise ValueError(
+                "Reward components are incompatible with standard fine-tuning"
+            )
+        if self.fine_tune_type == FineTuneType.PREFERENCE and self.reward_components:
+            raise ValueError(
+                "Reward components are incompatible with preference fine-tuning"
+            )
+
+        return self
+
+    @field_validator("fine_tune_type")
+    def validate_fine_tune_type(cls, v: Any) -> Any:
+        if v == FineTuneType.GRPO:
+            warnings.warn(
+                "FineTuneType.GRPO is deprecated and will be removed in a future version. Use FineTuneType.REINFORCEMENT",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        return v
+
+
+class AcceleratorType(str, Enum):
+    GAUDI2 = "GAUDI2"
+    GAUDI3 = "GAUDI3"
+    A100 = "A100"
+    A10 = "A10"
+    H100 = "H100"
+    MI300X = "MI300X"
+
+
+class InfrastructureConfig(BaseModel):
+    accel_type: AcceleratorType
+    n_accel: int
+    n_node: int = 1
+
+
+class FinetuneRequest(BaseModel):
+    """
+    Fine-tune request type
+    """
+
+    project_id: int
+    training_config: TrainingConfig
+    infrastructure_config: InfrastructureConfig
+
+
+class FinetuneResponse(BaseModel):
+    """
+    Fine-tune API response type
+    """
+
+    # job ID
+    id: str | None = None
+    # fine-tune type
+    fine_tune_type: FineTuneType = FineTuneType.STANDARD
+    reward_components: Optional[RewardComponents] = None
+    # training file id
+    training_files: List[str] | None = None
+    # validation file id
+    # validation_files: str | None = None TODO
+    # base model name
+    model: str | None = None
+    accel_type: AcceleratorType
+    n_accel: int
+    n_node: int | None = None
+    # number of epochs
+    n_epochs: int | None = None
+    # number of checkpoints to save
+    # n_checkpoints: int | None = None # TODO
+    # training batch size
+    batch_size: int | None = None
+    # training learning rate
+    learning_rate: float | None = None
+    # LoRA configuration returned when LoRA fine-tuning is enabled
+    lora_config: Optional[LoRAConfig] = None
+    # number of steps between evals
+    # eval_steps: int | None = None TODO
+    # created/updated datetime stamps
+    created_at: datetime | None = None
+    # updated_at: str | None = None
+    # up to 40 character suffix for output model name
+    experiment_name: str | None = None
+    # job status
+    status: FinetuneJobStatus | None = None
+    deleted_at: datetime | None = None
+
+    # list of fine-tune events
+    events: List[FinetuneEvent] | None = None
+    inference_available: bool = False
+    project_id: Optional[int] = None  # TODO - fix this
+    completed_at: datetime | None = None
+    description: str | None = None
+
+    # dataset token count
+    # TODO
+    # token_count: int | None = None
+    # # model parameter count
+    # param_count: int | None = None
+    # # fine-tune job price
+    # total_price: int | None = None
+    # # number of epochs completed (incrementing counter)
+    # epochs_completed: int | None = None
+    # # place in job queue (decrementing counter)
+    # queue_depth: int | None = None
+    # # weights & biases project name
+    # wandb_project_name: str | None = None
+    # # weights & biases job url
+    # wandb_url: str | None = None
+    # # training file metadata
+    # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
+    # training_file_size: int | None = Field(None, alias="TrainingFileSize")
+
+    @model_serializer(mode="wrap")
+    def serialize_model(
+        self, handler: Callable[[Any], dict[str, Any]]
+    ) -> dict[str, Any]:
+        # Remove 'reward_components' if it's None
+        dump_dict = handler(self)
+        if dump_dict.get("reward_components") is None:
+            del dump_dict["reward_components"]
+        return dump_dict
+
+
+class FinetuneList(BaseModel):
+    # object type
+    object: Literal["list"] | None = None
+    # list of fine-tune job objects
+    data: List[FinetuneResponse] | None = None
+
+
+class FinetuneListEvents(BaseModel):
+    # object type
+    object: Literal["list"] | None = None
+    # list of fine-tune events
+    data: List[FinetuneEvent] | None = None
+
+
+class FinetuneDownloadResult(BaseModel):
+    # object type
+    object: Literal["local"] | None = None
+    # fine-tune job id
+    id: str | None = None
+    # checkpoint step number
+    checkpoint_step: int | None = None
+    # local path filename
+    filename: str | None = None
+    # size in bytes
+    size: int | None = None
```
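Two behaviors in the new module are worth calling out: `Grader` coerces string operations to the matching enum, and `RewardComponents` normalizes grader weights so that, together with `format_reward_weight`, they sum to 1.0. A sketch of both, assuming the classes are importable from `seekrai.types.finetune` as defined above; field values marked as such are hypothetical:

```python
from seekrai.types.finetune import (
    FineTuneType,
    Grader,
    GraderType,
    RewardComponents,
    StringOperation,
    TrainingConfig,
)

# String operations are lower-cased and coerced to the StringOperation enum:
g = Grader(type=GraderType.STRING_CHECK, operation="CONTAINS")
assert g.operation is StringOperation.CONTAINS

# With no weights given, each grader starts at 1.0 and is scaled to fill
# 1 - format_reward_weight = 0.9; with two graders, each ends at 0.45.
rc = RewardComponents(
    graders=[
        Grader(type=GraderType.FORMAT_CHECK),
        Grader(type=GraderType.MATH_ACCURACY),
    ]
)
assert all(abs(gr.weight - 0.45) < 1e-9 for gr in rc.graders)

# REINFORCEMENT configs without reward components get the default backfilled
# by validate_reward_components (format 0.1 + a math_accuracy grader at 0.9).
cfg = TrainingConfig(
    training_files=["file-abc"],  # hypothetical file ID
    model="example-base-model",   # hypothetical model name
    n_epochs=1,
    learning_rate=1e-5,
    batch_size=8,
    fine_tune_type=FineTuneType.REINFORCEMENT,
)
assert cfg.reward_components is not None
assert cfg.reward_components.format_reward_weight == 0.1
```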
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/tools.py

```diff
@@ -62,6 +62,15 @@ class CreateRunPython(CreateTool):
     config: RunPythonConfig
 
 
+class AgentAsToolConfig(ToolConfig):
+    agent_id: str
+
+
+class CreateAgentAsTool(CreateTool):
+    type: Literal[ToolType.AGENT_AS_TOOL] = ToolType.AGENT_AS_TOOL
+    config: AgentAsToolConfig
+
+
 class Tool(BaseModel):
     type: ToolType
     name: Annotated[str, AfterValidator(validate_length)]
@@ -88,13 +97,19 @@ class RunPythonTool(Tool):
     config: RunPythonConfig
 
 
+class AgentAsTool(Tool):
+    type: Literal[ToolType.AGENT_AS_TOOL] = ToolType.AGENT_AS_TOOL
+    config: AgentAsToolConfig
+
+
 CreateToolRequest = Annotated[
-    Union[CreateFileSearch, CreateRunPython, CreateWebSearch],
+    Union[CreateFileSearch, CreateRunPython, CreateWebSearch, CreateAgentAsTool],
     Field(discriminator="type"),
 ]
 
 ToolResponse = Annotated[
-    Union[FileSearchTool, WebSearchTool, RunPythonTool],
+    Union[FileSearchTool, WebSearchTool, RunPythonTool, AgentAsTool],
+    Field(discriminator="type"),
 ]
 
 
@@ -110,6 +125,10 @@ class UpdateRunPython(CreateRunPython):
     pass
 
 
+class UpdateAgentAsTool(CreateAgentAsTool):
+    pass
+
+
 UpdateToolRequest = Annotated[
     Union[UpdateFileSearch, UpdateWebSearch, UpdateRunPython],
     Field(discriminator="type"),
@@ -117,7 +136,7 @@ UpdateToolRequest = Annotated[
 
 
 class GetToolsResponse(BaseModel):
-    """Response schema for paginated tool list."""
+    """Response schema for a paginated tool list."""
 
     data: list[ToolResponse]
     total: Optional[int] = None
```
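A sketch of building the new create/update request models. `CreateTool` is not shown in this diff, so the assumption is that it carries a `name` field like `Tool` does; all identifiers here are hypothetical:

```python
from seekrai.types.tools import (
    AgentAsToolConfig,
    CreateAgentAsTool,
    UpdateAgentAsTool,
)

# Assumes CreateTool exposes a `name` field like Tool does (not shown here).
request = CreateAgentAsTool(
    name="billing-helper",  # hypothetical tool name
    config=AgentAsToolConfig(agent_id="agent-123"),  # hypothetical agent ID
)

# UpdateAgentAsTool is a pass-through subclass, so it takes the same shape:
update = UpdateAgentAsTool(
    name="billing-helper",
    config=AgentAsToolConfig(agent_id="agent-456"),
)
```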
seekrai-0.5.26/src/seekrai/types/finetune.py (removed)

```diff
@@ -1,256 +0,0 @@
-from __future__ import annotations
-
-from datetime import datetime
-from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
-
-from pydantic import Field
-
-from seekrai.types.abstract import BaseModel
-from seekrai.types.common import (
-    ObjectType,
-)
-
-
-class FinetuneJobStatus(str, Enum):
-    """
-    Possible fine-tune job status
-    """
-
-    STATUS_PENDING = "pending"
-    STATUS_QUEUED = "queued"
-    STATUS_RUNNING = "running"
-    # STATUS_COMPRESSING = "compressing"
-    # STATUS_UPLOADING = "uploading"
-    STATUS_CANCEL_REQUESTED = "cancel_requested"
-    STATUS_CANCELLED = "cancelled"
-    STATUS_FAILED = "failed"
-    STATUS_COMPLETED = "completed"
-    STATUS_DELETED = "deleted"
-
-
-class FinetuneEventLevels(str, Enum):
-    """
-    Fine-tune job event status levels
-    """
-
-    NULL = ""
-    INFO = "Info"
-    WARNING = "Warning"
-    ERROR = "Error"
-    LEGACY_INFO = "info"
-    LEGACY_IWARNING = "warning"
-    LEGACY_IERROR = "error"
-
-
-class FinetuneEventType(str, Enum):
-    """
-    Fine-tune job event types
-    """
-
-    JOB_PENDING = "JOB_PENDING"
-    JOB_START = "JOB_START"
-    JOB_STOPPED = "JOB_STOPPED"
-    MODEL_DOWNLOADING = "MODEL_DOWNLOADING"
-    MODEL_DOWNLOAD_COMPLETE = "MODEL_DOWNLOAD_COMPLETE"
-    TRAINING_DATA_DOWNLOADING = "TRAINING_DATA_DOWNLOADING"
-    TRAINING_DATA_DOWNLOAD_COMPLETE = "TRAINING_DATA_DOWNLOAD_COMPLETE"
-    VALIDATION_DATA_DOWNLOADING = "VALIDATION_DATA_DOWNLOADING"
-    VALIDATION_DATA_DOWNLOAD_COMPLETE = "VALIDATION_DATA_DOWNLOAD_COMPLETE"
-    WANDB_INIT = "WANDB_INIT"
-    TRAINING_START = "TRAINING_START"
-    CHECKPOINT_SAVE = "CHECKPOINT_SAVE"
-    BILLING_LIMIT = "BILLING_LIMIT"
-    EPOCH_COMPLETE = "EPOCH_COMPLETE"
-    TRAINING_COMPLETE = "TRAINING_COMPLETE"
-    MODEL_COMPRESSING = "COMPRESSING_MODEL"
-    MODEL_COMPRESSION_COMPLETE = "MODEL_COMPRESSION_COMPLETE"
-    MODEL_UPLOADING = "MODEL_UPLOADING"
-    MODEL_UPLOAD_COMPLETE = "MODEL_UPLOAD_COMPLETE"
-    JOB_COMPLETE = "JOB_COMPLETE"
-    JOB_ERROR = "JOB_ERROR"
-    CANCEL_REQUESTED = "CANCEL_REQUESTED"
-    JOB_RESTARTED = "JOB_RESTARTED"
-    REFUND = "REFUND"
-    WARNING = "WARNING"
-
-
-class FineTuneType(str, Enum):
-    STANDARD = "STANDARD"
-    PREFERENCE = "PREFERENCE"
-    GRPO = "GRPO"
-
-
-class FinetuneEvent(BaseModel):
-    """
-    Fine-tune event type
-    """
-
-    # object type
-    object: Literal[ObjectType.FinetuneEvent]
-    # created at datetime stamp
-    created_at: datetime | None = None
-    # metrics that we expose
-    loss: float | None = None
-    epoch: float | None = None
-
-
-class LoRAConfig(BaseModel):
-    r: int = Field(8, gt=0, description="Rank of the update matrices.")
-    alpha: int = Field(32, gt=0, description="Scaling factor applied to LoRA updates.")
-    dropout: float = Field(
-        0.1,
-        ge=0.0,
-        le=1.0,
-        description="Fraction of LoRA neurons dropped during training.",
-    )
-    bias: Literal["none", "all", "lora_only"] = Field(
-        "none",
-        description="Bias terms to train; choose from 'none', 'all', or 'lora_only'.",
-    )
-    extras: Dict[str, Any] = Field(default_factory=dict)
-
-
-class TrainingConfig(BaseModel):
-    # training file ID
-    training_files: List[str]
-    # base model string
-    model: str
-    # number of epochs to train for
-    n_epochs: int
-    # training learning rate
-    learning_rate: float
-    # number of checkpoints to save
-    n_checkpoints: int | None = None
-    # training batch size
-    batch_size: int = Field(..., ge=1, le=1024)
-    # up to 40 character suffix for output model name
-    experiment_name: str | None = None
-    # sequence length
-    max_length: int = 2500
-    # # weights & biases api key
-    # wandb_key: str | None = None
-    # IFT by default
-    pre_train: bool = False
-    # fine-tune type
-    fine_tune_type: FineTuneType = FineTuneType.STANDARD
-    # LoRA config
-    lora_config: Optional[LoRAConfig] = None
-
-
-class AcceleratorType(str, Enum):
-    GAUDI2 = "GAUDI2"
-    GAUDI3 = "GAUDI3"
-    A100 = "A100"
-    A10 = "A10"
-    H100 = "H100"
-    MI300X = "MI300X"
-
-
-class InfrastructureConfig(BaseModel):
-    accel_type: AcceleratorType
-    n_accel: int
-    n_node: int = 1
-
-
-class FinetuneRequest(BaseModel):
-    """
-    Fine-tune request type
-    """
-
-    project_id: int
-    training_config: TrainingConfig
-    infrastructure_config: InfrastructureConfig
-
-
-class FinetuneResponse(BaseModel):
-    """
-    Fine-tune API response type
-    """
-
-    # job ID
-    id: str | None = None
-    # fine-tune type
-    fine_tune_type: FineTuneType = FineTuneType.STANDARD
-    # training file id
-    training_files: List[str] | None = None
-    # validation file id
-    # validation_files: str | None = None TODO
-    # base model name
-    model: str | None = None
-    accel_type: AcceleratorType
-    n_accel: int
-    n_node: int | None = None
-    # number of epochs
-    n_epochs: int | None = None
-    # number of checkpoints to save
-    # n_checkpoints: int | None = None # TODO
-    # training batch size
-    batch_size: int | None = None
-    # training learning rate
-    learning_rate: float | None = None
-    # LoRA configuration returned when LoRA fine-tuning is enabled
-    lora_config: Optional[LoRAConfig] = None
-    # number of steps between evals
-    # eval_steps: int | None = None TODO
-    # created/updated datetime stamps
-    created_at: datetime | None = None
-    # updated_at: str | None = None
-    # up to 40 character suffix for output model name
-    experiment_name: str | None = None
-    # job status
-    status: FinetuneJobStatus | None = None
-    deleted_at: datetime | None = None
-
-    # list of fine-tune events
-    events: List[FinetuneEvent] | None = None
-    inference_available: bool = False
-    project_id: Optional[int] = None  # TODO - fix this
-    completed_at: datetime | None = None
-    description: str | None = None
-
-    # dataset token count
-    # TODO
-    # token_count: int | None = None
-    # # model parameter count
-    # param_count: int | None = None
-    # # fine-tune job price
-    # total_price: int | None = None
-    # # number of epochs completed (incrementing counter)
-    # epochs_completed: int | None = None
-    # # place in job queue (decrementing counter)
-    # queue_depth: int | None = None
-    # # weights & biases project name
-    # wandb_project_name: str | None = None
-    # # weights & biases job url
-    # wandb_url: str | None = None
-    # # training file metadata
-    # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
-    # training_file_size: int | None = Field(None, alias="TrainingFileSize")
-
-
-class FinetuneList(BaseModel):
-    # object type
-    object: Literal["list"] | None = None
-    # list of fine-tune job objects
-    data: List[FinetuneResponse] | None = None
-
-
-class FinetuneListEvents(BaseModel):
-    # object type
-    object: Literal["list"] | None = None
-    # list of fine-tune events
-    data: List[FinetuneEvent] | None = None
-
-
-class FinetuneDownloadResult(BaseModel):
-    # object type
-    object: Literal["local"] | None = None
-    # fine-tune job id
-    id: str | None = None
-    # checkpoint step number
-    checkpoint_step: int | None = None
-    # local path filename
-    filename: str | None = None
-    # size in bytes
-    size: int | None = None
```