seekrai 0.5.26__tar.gz → 0.5.29__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. {seekrai-0.5.26 → seekrai-0.5.29}/PKG-INFO +1 -1
  2. {seekrai-0.5.26 → seekrai-0.5.29}/pyproject.toml +1 -1
  3. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/__init__.py +11 -0
  4. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/__init__.py +2 -0
  5. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/__init__.py +2 -0
  6. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/__init__.py +2 -0
  7. seekrai-0.5.29/src/seekrai/types/agents/tools/schemas/agent_as_tool.py +16 -0
  8. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/tool_types.py +5 -1
  9. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/enums.py +1 -0
  10. seekrai-0.5.29/src/seekrai/types/finetune.py +459 -0
  11. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/tools.py +22 -3
  12. seekrai-0.5.26/src/seekrai/types/finetune.py +0 -256
  13. {seekrai-0.5.26 → seekrai-0.5.29}/LICENSE +0 -0
  14. {seekrai-0.5.26 → seekrai-0.5.29}/README.md +0 -0
  15. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/__init__.py +0 -0
  16. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/abstract/__init__.py +0 -0
  17. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/abstract/api_requestor.py +0 -0
  18. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/abstract/response_parsing.py +0 -0
  19. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/client.py +0 -0
  20. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/constants.py +0 -0
  21. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/error.py +0 -0
  22. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/filemanager.py +0 -0
  23. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/__init__.py +0 -0
  24. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/__init__.py +0 -0
  25. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/agent_inference.py +0 -0
  26. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/agent_observability.py +0 -0
  27. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/agents.py +0 -0
  28. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/python_functions.py +0 -0
  29. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/agents/threads.py +0 -0
  30. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/alignment.py +0 -0
  31. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/chat/__init__.py +0 -0
  32. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/chat/completions.py +0 -0
  33. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/completions.py +0 -0
  34. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/deployments.py +0 -0
  35. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/embeddings.py +0 -0
  36. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/explainability.py +0 -0
  37. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/files.py +0 -0
  38. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/finetune.py +0 -0
  39. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/images.py +0 -0
  40. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/ingestion.py +0 -0
  41. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/models.py +0 -0
  42. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/projects.py +0 -0
  43. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/resource_base.py +0 -0
  44. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/tools.py +0 -0
  45. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/resources/vectordb.py +0 -0
  46. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/seekrflow_response.py +0 -0
  47. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/abstract.py +0 -0
  48. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/agent.py +0 -0
  49. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/observability.py +0 -0
  50. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/python_functions.py +0 -0
  51. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/runs.py +0 -0
  52. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/threads.py +0 -0
  53. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/env_model_config.py +0 -0
  54. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/file_search.py +0 -0
  55. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/file_search_env.py +0 -0
  56. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/run_python.py +0 -0
  57. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/run_python_env.py +0 -0
  58. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/web_search.py +0 -0
  59. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/web_search_env.py +0 -0
  60. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/tool.py +0 -0
  61. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/alignment.py +0 -0
  62. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/chat_completions.py +0 -0
  63. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/common.py +0 -0
  64. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/completions.py +0 -0
  65. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/deployments.py +0 -0
  66. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/embeddings.py +0 -0
  67. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/error.py +0 -0
  68. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/explainability.py +0 -0
  69. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/files.py +0 -0
  70. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/images.py +0 -0
  71. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/ingestion.py +0 -0
  72. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/models.py +0 -0
  73. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/projects.py +0 -0
  74. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/vectordb.py +0 -0
  75. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/__init__.py +0 -0
  76. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/_log.py +0 -0
  77. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/api_helpers.py +0 -0
  78. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/files.py +0 -0
  79. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/utils/tools.py +0 -0
  80. {seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/version.py +0 -0
{seekrai-0.5.26 → seekrai-0.5.29}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: seekrai
- Version: 0.5.26
+ Version: 0.5.29
  Summary: Python client for SeekrAI
  License: Apache-2.0
  License-File: LICENSE
{seekrai-0.5.26 → seekrai-0.5.29}/pyproject.toml
@@ -13,7 +13,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "seekrai"
- version = "0.5.26"
+ version = "0.5.29"
  authors = [
      "SeekrFlow <support@seekr.com>"
  ]
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/__init__.py
@@ -1,6 +1,7 @@
  from seekrai.types.abstract import SeekrFlowClient
  from seekrai.types.agents import (
      Agent,
+     AgentAsToolLegacy,
      AgentDeleteResponse,
      CreateAgentRequest,
      EnvConfig,
@@ -47,6 +48,7 @@ from seekrai.types.agents import (
      WebSearchEnv,
  )
  from seekrai.types.agents.tools.schemas import (
+     AgentAsToolLegacy,
      FileSearch,
      FileSearchEnv,
      RunPython,
@@ -120,6 +122,9 @@ from seekrai.types.projects import (
      ProjectWithRuns,
  )
  from seekrai.types.tools import (
+     AgentAsTool,
+     AgentAsToolConfig,
+     CreateAgentAsTool,
      CreateFileSearch,
      CreateRunPython,
      CreateToolRequest,
@@ -132,6 +137,7 @@ from seekrai.types.tools import (
      ToolAgentSummaryResponse,
      ToolDeleteResponse,
      ToolResponse,
+     UpdateAgentAsTool,
      UpdateFileSearch,
      UpdateRunPython,
      UpdateToolRequest,
@@ -258,4 +264,9 @@ __all__ = [
      "RunPythonEnv",
      "WebSearch",
      "WebSearchEnv",
+     "AgentAsToolLegacy",
+     "AgentAsTool",
+     "CreateAgentAsTool",
+     "AgentAsToolConfig",
+     "UpdateAgentAsTool",
  ]
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/__init__.py
@@ -51,6 +51,7 @@ from seekrai.types.agents.tools import (
      ToolBase,
  )
  from seekrai.types.agents.tools.schemas import (
+     AgentAsToolLegacy,
      FileSearch,
      FileSearchEnv,
      RunPython,
@@ -106,6 +107,7 @@ __all__ = [
      "RunPythonEnv",
      "WebSearch",
      "WebSearchEnv",
+     "AgentAsToolLegacy",
      "PythonFunctionBase",
      "PythonFunctionResponse",
      "DeletePythonFunctionResponse",
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/__init__.py
@@ -1,5 +1,6 @@
  from seekrai.types.agents.tools.env_model_config import EnvConfig
  from seekrai.types.agents.tools.schemas import (
+     AgentAsToolLegacy,
      FileSearch,
      FileSearchEnv,
      RunPython,
@@ -23,4 +24,5 @@ __all__ = [
      "RunPythonEnv",
      "WebSearch",
      "WebSearchEnv",
+     "AgentAsToolLegacy",
  ]
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/schemas/__init__.py
@@ -1,3 +1,4 @@
+ from seekrai.types.agents.tools.schemas.agent_as_tool import AgentAsToolLegacy
  from seekrai.types.agents.tools.schemas.file_search import FileSearch
  from seekrai.types.agents.tools.schemas.file_search_env import FileSearchEnv
  from seekrai.types.agents.tools.schemas.run_python import RunPython
@@ -13,4 +14,5 @@ __all__ = [
      "RunPythonEnv",
      "WebSearch",
      "WebSearchEnv",
+     "AgentAsToolLegacy",
  ]
seekrai-0.5.29/src/seekrai/types/agents/tools/schemas/agent_as_tool.py (new file)
@@ -0,0 +1,16 @@
+ from typing import Literal
+
+ from seekrai.types.agents.tools import EnvConfig
+ from seekrai.types.agents.tools.tool import ToolBase, ToolType
+
+
+ class AgentAsToolLegacy(ToolBase[Literal["agent_as_tool"], EnvConfig]):
+     name: Literal["agent_as_tool"] = ToolType.AGENT_AS_TOOL.value
+     description: str
+     agent_id: str
+
+     model_config = {
+         "json_schema_extra": {
+             "deprecated": True,
+         }
+     }
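The new AgentAsToolLegacy schema keeps the pre-existing agent-as-tool payload shape addressable while marking it deprecated in the generated JSON schema. A minimal usage sketch (values are placeholders, and it assumes ToolBase[Literal["agent_as_tool"], EnvConfig] adds no further required fields beyond those shown here):

from seekrai.types.agents.tools.schemas import AgentAsToolLegacy

# "name" defaults to "agent_as_tool", so only description and agent_id are needed.
legacy_tool = AgentAsToolLegacy(
    description="Delegate billing questions to a dedicated agent",
    agent_id="agent-123",  # placeholder agent ID
)

# json_schema_extra in model_config surfaces the deprecation in the schema.
assert AgentAsToolLegacy.model_json_schema().get("deprecated") is True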
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/agents/tools/tool_types.py
@@ -2,9 +2,13 @@ from typing import Annotated, Union

  from pydantic import Field

+ from seekrai.types.agents.tools.schemas.agent_as_tool import AgentAsToolLegacy
  from seekrai.types.agents.tools.schemas.file_search import FileSearch
  from seekrai.types.agents.tools.schemas.run_python import RunPython
  from seekrai.types.agents.tools.schemas.web_search import WebSearch


- Tool = Annotated[Union[FileSearch, WebSearch, RunPython], Field(discriminator="name")]
+ Tool = Annotated[
+     Union[FileSearch, WebSearch, RunPython, AgentAsToolLegacy],
+     Field(discriminator="name"),
+ ]
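With AgentAsToolLegacy added to the discriminated Tool union, pydantic routes raw payloads by the "name" field. A sketch of that routing (payload values are hypothetical; it assumes the fields shown above are the only required ones):

from pydantic import TypeAdapter

from seekrai.types.agents.tools.tool_types import Tool

payload = {
    "name": "agent_as_tool",  # discriminator value selects AgentAsToolLegacy
    "description": "Hand off research questions to another agent",
    "agent_id": "agent-456",  # placeholder agent ID
}

tool = TypeAdapter(Tool).validate_python(payload)
print(type(tool).__name__)  # AgentAsToolLegacy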
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/enums.py
@@ -28,3 +28,4 @@ class ToolType(str, Enum):
      FILE_SEARCH = "file_search"
      WEB_SEARCH = "web_search"
      RUN_PYTHON = "run_python"
+     AGENT_AS_TOOL = "agent_as_tool"
seekrai-0.5.29/src/seekrai/types/finetune.py (new file)
@@ -0,0 +1,459 @@
+ from __future__ import annotations
+
+ import warnings
+ from datetime import datetime
+ from enum import Enum
+ from typing import Any, Callable, Dict, List, Literal, Optional
+
+ from pydantic import Field, field_validator, model_serializer, model_validator
+
+ from seekrai.types.abstract import BaseModel
+ from seekrai.types.common import (
+     ObjectType,
+ )
+ from seekrai.utils._log import log_info
+
+
+ class FinetuneJobStatus(str, Enum):
+     """
+     Possible fine-tune job status
+     """
+
+     STATUS_PENDING = "pending"
+     STATUS_QUEUED = "queued"
+     STATUS_RUNNING = "running"
+     # STATUS_COMPRESSING = "compressing"
+     # STATUS_UPLOADING = "uploading"
+     STATUS_CANCEL_REQUESTED = "cancel_requested"
+     STATUS_CANCELLED = "cancelled"
+     STATUS_FAILED = "failed"
+     STATUS_COMPLETED = "completed"
+     STATUS_DELETED = "deleted"
+
+
+ class FinetuneEventLevels(str, Enum):
+     """
+     Fine-tune job event status levels
+     """
+
+     NULL = ""
+     INFO = "Info"
+     WARNING = "Warning"
+     ERROR = "Error"
+     LEGACY_INFO = "info"
+     LEGACY_IWARNING = "warning"
+     LEGACY_IERROR = "error"
+
+
+ class FinetuneEventType(str, Enum):
+     """
+     Fine-tune job event types
+     """
+
+     JOB_PENDING = "JOB_PENDING"
+     JOB_START = "JOB_START"
+     JOB_STOPPED = "JOB_STOPPED"
+     MODEL_DOWNLOADING = "MODEL_DOWNLOADING"
+     MODEL_DOWNLOAD_COMPLETE = "MODEL_DOWNLOAD_COMPLETE"
+     TRAINING_DATA_DOWNLOADING = "TRAINING_DATA_DOWNLOADING"
+     TRAINING_DATA_DOWNLOAD_COMPLETE = "TRAINING_DATA_DOWNLOAD_COMPLETE"
+     VALIDATION_DATA_DOWNLOADING = "VALIDATION_DATA_DOWNLOADING"
+     VALIDATION_DATA_DOWNLOAD_COMPLETE = "VALIDATION_DATA_DOWNLOAD_COMPLETE"
+     WANDB_INIT = "WANDB_INIT"
+     TRAINING_START = "TRAINING_START"
+     CHECKPOINT_SAVE = "CHECKPOINT_SAVE"
+     BILLING_LIMIT = "BILLING_LIMIT"
+     EPOCH_COMPLETE = "EPOCH_COMPLETE"
+     TRAINING_COMPLETE = "TRAINING_COMPLETE"
+     MODEL_COMPRESSING = "COMPRESSING_MODEL"
+     MODEL_COMPRESSION_COMPLETE = "MODEL_COMPRESSION_COMPLETE"
+     MODEL_UPLOADING = "MODEL_UPLOADING"
+     MODEL_UPLOAD_COMPLETE = "MODEL_UPLOAD_COMPLETE"
+     JOB_COMPLETE = "JOB_COMPLETE"
+     JOB_ERROR = "JOB_ERROR"
+     CANCEL_REQUESTED = "CANCEL_REQUESTED"
+     JOB_RESTARTED = "JOB_RESTARTED"
+     REFUND = "REFUND"
+     WARNING = "WARNING"
+
+
+ class FineTuneType(str, Enum):
+     STANDARD = "STANDARD"
+     GRPO = "GRPO"  # deprecated
+     PREFERENCE = "PREFERENCE"
+     REINFORCEMENT = "REINFORCEMENT"
+
+
+ class GraderType(str, Enum):
+     FORMAT_CHECK = "format_check"
+     MATH_ACCURACY = "math_accuracy"
+     STRING_CHECK = "string_check"
+     TEXT_SIMILARITY = "text_similarity"
+
+
+ class StringOperation(str, Enum):
+     EQUALS = "equals"
+     NOT_EQUALS = "not_equals"
+     CONTAINS = "contains"
+     CASE_INSENSITIVE_CONTAINS = "case_insensitive_contains"
+
+
+ class TextSimilarityOperation(str, Enum):
+     BLEU = "bleu"
+     ROUGE = "rouge"
+
+
+ class FinetuneEvent(BaseModel):
+     """
+     Fine-tune event type
+     """
+
+     # object type
+     object: Literal[ObjectType.FinetuneEvent]
+     # created at datetime stamp
+     created_at: datetime | None = None
+     # metrics that we expose
+     loss: float | None = None
+     epoch: float | None = None
+     reward: float | None = None
+
+     @model_serializer(mode="wrap")
+     def serialize_model(
+         self, handler: Callable[[Any], dict[str, Any]]
+     ) -> dict[str, Any]:
+         # Remove 'reward' if it's None
+         dump_dict = handler(self)
+         if dump_dict.get("reward") is None:
+             del dump_dict["reward"]
+         return dump_dict
+
+
+ class LoRAConfig(BaseModel):
+     r: int = Field(8, gt=0, description="Rank of the update matrices.")
+     alpha: int = Field(32, gt=0, description="Scaling factor applied to LoRA updates.")
+     dropout: float = Field(
+         0.1,
+         ge=0.0,
+         le=1.0,
+         description="Fraction of LoRA neurons dropped during training.",
+     )
+     bias: Literal["none", "all", "lora_only"] = Field(
+         "none",
+         description="Bias terms to train; choose from 'none', 'all', or 'lora_only'.",
+     )
+     extras: Dict[str, Any] = Field(default_factory=dict)
+
+
+ class Grader(BaseModel):
+     type: GraderType
+     weight: float | None = Field(default=None, gt=0.0, le=1.0)
+     operation: StringOperation | TextSimilarityOperation | None = Field(default=None)
+
+     @model_validator(mode="before")
+     @classmethod
+     def validate_operation(cls, data: Any) -> Any:
+         if not isinstance(data, dict):
+             return data
+
+         grader_type = data.get("type")
+         operation_value = data.get("operation")
+
+         if grader_type == GraderType.STRING_CHECK:
+             if not operation_value:
+                 raise ValueError(
+                     "string_check grader is missing required StringOperation"
+                 )
+             if isinstance(operation_value, str):
+                 try:
+                     # Convert to enum to validate it's a valid value
+                     data["operation"] = StringOperation(operation_value.lower())
+                 except ValueError:
+                     raise ValueError(
+                         f"Invalid operation for string_check grader: "
+                         f"expected StringOperation, but got type '{type(operation_value).__name__}' with value '{operation_value}'"
+                     )
+         elif grader_type == GraderType.TEXT_SIMILARITY:
+             if not operation_value:
+                 raise ValueError(
+                     "text_similarity grader is missing required TextSimilarityOperation"
+                 )
+             if isinstance(operation_value, str):
+                 try:
+                     data["operation"] = TextSimilarityOperation(operation_value.lower())
+                 except ValueError:
+                     raise ValueError(
+                         f"Invalid operation for text_similarity grader: "
+                         f"expected TextSimilarityOperation, got type '{type(operation_value).__name__}' with value '{operation_value}'"
+                     )
+
+         elif grader_type in (GraderType.FORMAT_CHECK, GraderType.MATH_ACCURACY):
+             if operation_value:
+                 raise ValueError(f"{grader_type} grader cannot have an operation")
+             data["operation"] = None
+
+         return data
+
+
+ class RewardComponents(BaseModel):
+     format_reward_weight: float = Field(default=0.1, gt=0.0, le=1.0)
+     graders: list[Grader] = Field(min_length=1)
+
+     @model_validator(mode="after")
+     def validate_weights(self) -> "RewardComponents":
+         is_format_weight_specified = "format_reward_weight" in self.model_fields_set
+
+         grader_weights_specified = [
+             grader.weight is not None for grader in self.graders
+         ]
+
+         all_graders_have_weights = all(grader_weights_specified)
+         some_graders_have_weights = any(grader_weights_specified) and not all(
+             grader_weights_specified
+         )
+         no_graders_have_weights = not any(grader_weights_specified)
+
+         if some_graders_have_weights:
+             raise ValueError(
+                 "Only some graders have weights specified. Either all graders must have weights specified, or none of them."
+             )
+
+         if all_graders_have_weights and is_format_weight_specified:
+             self._validate_weights_sum_to_one()
+
+         elif all_graders_have_weights and not is_format_weight_specified:
+             self._normalize_grader_weights()
+
+         elif no_graders_have_weights:
+             self._initialize_grader_weights()
+             self._normalize_grader_weights()
+
+         return self
+
+     def _validate_weights_sum_to_one(self) -> None:
+         """Validate that format_reward_weight and grader weights sum to 1.0"""
+         total_weight = self.format_reward_weight + sum(  # type: ignore[operator]
+             grader.weight  # type: ignore[misc]
+             for grader in self.graders
+         )
+
+         if abs(total_weight - 1.0) > 1e-10:
+             raise ValueError(
+                 f"When all weights are explicitly provided, they must sum to 1.0. "
+                 f"Got format_reward_weight={self.format_reward_weight}, "
+                 f"graders={self.graders}"
+             )
+
+     def _normalize_grader_weights(self) -> None:
+         """Normalize only grader weights to fill (1 - format_reward_weight)"""
+         total_grader_weight = sum(grader.weight for grader in self.graders)  # type: ignore[misc]
+         target_grader_total = 1.0 - self.format_reward_weight
+
+         # only normalize if weights aren't already properly normalized
+         if abs(total_grader_weight - target_grader_total) > 1e-10:
+             scale_factor = target_grader_total / total_grader_weight
+             for grader in self.graders:
+                 original_weight = grader.weight
+                 grader.weight *= scale_factor  # type: ignore[operator]
+                 log_info(
+                     f"{grader.type}'s weight scaled from {original_weight} to {grader.weight:.2f}"
+                 )
+
+     def _initialize_grader_weights(self) -> None:
+         """Initialize all grader weights when none are provided"""
+         for grader in self.graders:
+             grader.weight = 1.0
+
+
+ class TrainingConfig(BaseModel):
+     # training file ID
+     training_files: List[str]
+     # base model string
+     model: str
+     # number of epochs to train for
+     n_epochs: int
+     # training learning rate
+     learning_rate: float
+     # number of checkpoints to save
+     n_checkpoints: int | None = None
+     # training batch size
+     batch_size: int = Field(..., ge=1, le=1024)
+     # up to 40 character suffix for output model name
+     experiment_name: str | None = None
+     # sequence length
+     max_length: int = 2500
+     # # weights & biases api key
+     # wandb_key: str | None = None
+     # IFT by default
+     pre_train: bool = False
+     # fine-tune type
+     fine_tune_type: FineTuneType = FineTuneType.STANDARD
+     # LoRA config
+     lora_config: Optional[LoRAConfig] = None
+     # reward_components are REINFORCEMENT-specific
+     reward_components: Optional[RewardComponents] = None
+
+     @model_validator(mode="after")
+     def validate_reward_components(self) -> "TrainingConfig":
+         # TODO: re-enable the below and make reward_components required for REINFORCEMENT. Disabled for now for backwards-compatibility
+         # if (
+         #     self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
+         #     and not self.reward_components
+         # ):
+         #     raise ValueError("REINFORCEMENT fine-tuning requires reward components")
+         if (
+             self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
+             and not self.reward_components
+         ):
+             self.reward_components = RewardComponents(
+                 format_reward_weight=0.1,
+                 graders=[Grader(type=GraderType.MATH_ACCURACY, weight=0.9)],
+             )
+         if self.fine_tune_type == FineTuneType.STANDARD and self.reward_components:
+             raise ValueError(
+                 "Reward components are incompatible with standard fine-tuning"
+             )
+         if self.fine_tune_type == FineTuneType.PREFERENCE and self.reward_components:
+             raise ValueError(
+                 "Reward components are incompatible with preference fine-tuning"
+             )
+
+         return self
+
+     @field_validator("fine_tune_type")
+     def validate_fine_tune_type(cls, v: Any) -> Any:
+         if v == FineTuneType.GRPO:
+             warnings.warn(
+                 "FineTuneType.GRPO is deprecated and will be removed in a future version. Use FineTuneType.REINFORCEMENT",
+                 DeprecationWarning,
+                 stacklevel=2,
+             )
+         return v
+
+
+ class AcceleratorType(str, Enum):
+     GAUDI2 = "GAUDI2"
+     GAUDI3 = "GAUDI3"
+     A100 = "A100"
+     A10 = "A10"
+     H100 = "H100"
+     MI300X = "MI300X"
+
+
+ class InfrastructureConfig(BaseModel):
+     accel_type: AcceleratorType
+     n_accel: int
+     n_node: int = 1
+
+
+ class FinetuneRequest(BaseModel):
+     """
+     Fine-tune request type
+     """
+
+     project_id: int
+     training_config: TrainingConfig
+     infrastructure_config: InfrastructureConfig
+
+
+ class FinetuneResponse(BaseModel):
+     """
+     Fine-tune API response type
+     """
+
+     # job ID
+     id: str | None = None
+     # fine-tune type
+     fine_tune_type: FineTuneType = FineTuneType.STANDARD
+     reward_components: Optional[RewardComponents] = None
+     # training file id
+     training_files: List[str] | None = None
+     # validation file id
+     # validation_files: str | None = None TODO
+     # base model name
+     model: str | None = None
+     accel_type: AcceleratorType
+     n_accel: int
+     n_node: int | None = None
+     # number of epochs
+     n_epochs: int | None = None
+     # number of checkpoints to save
+     # n_checkpoints: int | None = None # TODO
+     # training batch size
+     batch_size: int | None = None
+     # training learning rate
+     learning_rate: float | None = None
+     # LoRA configuration returned when LoRA fine-tuning is enabled
+     lora_config: Optional[LoRAConfig] = None
+     # number of steps between evals
+     # eval_steps: int | None = None TODO
+     # created/updated datetime stamps
+     created_at: datetime | None = None
+     # updated_at: str | None = None
+     # up to 40 character suffix for output model name
+     experiment_name: str | None = None
+     # job status
+     status: FinetuneJobStatus | None = None
+     deleted_at: datetime | None = None
+
+     # list of fine-tune events
+     events: List[FinetuneEvent] | None = None
+     inference_available: bool = False
+     project_id: Optional[int] = None  # TODO - fix this
+     completed_at: datetime | None = None
+     description: str | None = None
+
+     # dataset token count
+     # TODO
+     # token_count: int | None = None
+     # # model parameter count
+     # param_count: int | None = None
+     # # fine-tune job price
+     # total_price: int | None = None
+     # # number of epochs completed (incrementing counter)
+     # epochs_completed: int | None = None
+     # # place in job queue (decrementing counter)
+     # queue_depth: int | None = None
+     # # weights & biases project name
+     # wandb_project_name: str | None = None
+     # # weights & biases job url
+     # wandb_url: str | None = None
+     # # training file metadata
+     # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
+     # training_file_size: int | None = Field(None, alias="TrainingFileSize")
+
+     @model_serializer(mode="wrap")
+     def serialize_model(
+         self, handler: Callable[[Any], dict[str, Any]]
+     ) -> dict[str, Any]:
+         # Remove 'reward_components' if it's None
+         dump_dict = handler(self)
+         if dump_dict.get("reward_components") is None:
+             del dump_dict["reward_components"]
+         return dump_dict
+
+
+ class FinetuneList(BaseModel):
+     # object type
+     object: Literal["list"] | None = None
+     # list of fine-tune job objects
+     data: List[FinetuneResponse] | None = None
+
+
+ class FinetuneListEvents(BaseModel):
+     # object type
+     object: Literal["list"] | None = None
+     # list of fine-tune events
+     data: List[FinetuneEvent] | None = None
+
+
+ class FinetuneDownloadResult(BaseModel):
+     # object type
+     object: Literal["local"] | None = None
+     # fine-tune job id
+     id: str | None = None
+     # checkpoint step number
+     checkpoint_step: int | None = None
+     # local path filename
+     filename: str | None = None
+     # size in bytes
+     size: int | None = None
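The rewritten finetune.py introduces reinforcement fine-tuning support: Grader, GraderType, and RewardComponents with weight normalization, a REINFORCEMENT fine-tune type that supersedes the now-deprecated GRPO, and a reward field on FinetuneEvent. A short sketch of how the validators behave (the file ID and model name are placeholders):

from seekrai.types.finetune import (
    FineTuneType,
    Grader,
    GraderType,
    RewardComponents,
    StringOperation,
    TrainingConfig,
)

# No weights supplied: validate_weights() first sets each grader's weight to 1.0,
# then rescales both to 0.45 so they fill 1 - format_reward_weight (0.9).
components = RewardComponents(
    graders=[
        Grader(type=GraderType.MATH_ACCURACY),
        Grader(type=GraderType.STRING_CHECK, operation=StringOperation.CONTAINS),
    ]
)

config = TrainingConfig(
    training_files=["file-abc123"],  # placeholder file ID
    model="example-org/example-base-model",  # placeholder base model name
    n_epochs=3,
    learning_rate=1e-5,
    batch_size=8,
    fine_tune_type=FineTuneType.REINFORCEMENT,
    reward_components=components,
)

# Omitting reward_components for REINFORCEMENT (or the deprecated GRPO) makes
# validate_reward_components() fall back to a single math_accuracy grader with
# weight 0.9 and format_reward_weight 0.1.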
{seekrai-0.5.26 → seekrai-0.5.29}/src/seekrai/types/tools.py
@@ -62,6 +62,15 @@ class CreateRunPython(CreateTool):
      config: RunPythonConfig


+ class AgentAsToolConfig(ToolConfig):
+     agent_id: str
+
+
+ class CreateAgentAsTool(CreateTool):
+     type: Literal[ToolType.AGENT_AS_TOOL] = ToolType.AGENT_AS_TOOL
+     config: AgentAsToolConfig
+
+
  class Tool(BaseModel):
      type: ToolType
      name: Annotated[str, AfterValidator(validate_length)]
@@ -88,13 +97,19 @@ class RunPythonTool(Tool):
      config: RunPythonConfig


+ class AgentAsTool(Tool):
+     type: Literal[ToolType.AGENT_AS_TOOL] = ToolType.AGENT_AS_TOOL
+     config: AgentAsToolConfig
+
+
  CreateToolRequest = Annotated[
-     Union[CreateFileSearch, CreateRunPython, CreateWebSearch],
+     Union[CreateFileSearch, CreateRunPython, CreateWebSearch, CreateAgentAsTool],
      Field(discriminator="type"),
  ]

  ToolResponse = Annotated[
-     Union[FileSearchTool, WebSearchTool, RunPythonTool], Field(discriminator="type")
+     Union[FileSearchTool, WebSearchTool, RunPythonTool, AgentAsTool],
+     Field(discriminator="type"),
  ]


@@ -110,6 +125,10 @@ class UpdateRunPython(CreateRunPython):
      pass


+ class UpdateAgentAsTool(CreateAgentAsTool):
+     pass
+
+
  UpdateToolRequest = Annotated[
      Union[UpdateFileSearch, UpdateWebSearch, UpdateRunPython],
      Field(discriminator="type"),
@@ -117,7 +136,7 @@ UpdateToolRequest = Annotated[


  class GetToolsResponse(BaseModel):
-     """Response schema for paginated tool list."""
+     """Response schema for a paginated tool list."""

      data: list[ToolResponse]
      total: Optional[int] = None
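These tools.py additions mirror the agent-as-tool schema on the tools API side: AgentAsToolConfig, CreateAgentAsTool, AgentAsTool, and UpdateAgentAsTool, with the first three wired into the CreateToolRequest and ToolResponse discriminated unions (UpdateToolRequest itself is unchanged in this diff). A sketch of building a create request; only type and config come from the hunk above, so the remaining field is an assumption about CreateTool's base schema:

from pydantic import TypeAdapter

from seekrai.types import AgentAsToolConfig, CreateAgentAsTool, CreateToolRequest

payload = {
    "type": "agent_as_tool",  # discriminator routes to CreateAgentAsTool
    "name": "billing_helper",  # assumed CreateTool base field
    "config": {"agent_id": "agent-789"},  # placeholder agent ID
}

request = TypeAdapter(CreateToolRequest).validate_python(payload)
assert isinstance(request, CreateAgentAsTool)
assert isinstance(request.config, AgentAsToolConfig)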
seekrai-0.5.26/src/seekrai/types/finetune.py (removed; replaced by the 0.5.29 version above)
@@ -1,256 +0,0 @@
- from __future__ import annotations
-
- from datetime import datetime
- from enum import Enum
- from typing import Any, Dict, List, Literal, Optional
-
- from pydantic import Field
-
- from seekrai.types.abstract import BaseModel
- from seekrai.types.common import (
-     ObjectType,
- )
-
-
- class FinetuneJobStatus(str, Enum):
-     """
-     Possible fine-tune job status
-     """
-
-     STATUS_PENDING = "pending"
-     STATUS_QUEUED = "queued"
-     STATUS_RUNNING = "running"
-     # STATUS_COMPRESSING = "compressing"
-     # STATUS_UPLOADING = "uploading"
-     STATUS_CANCEL_REQUESTED = "cancel_requested"
-     STATUS_CANCELLED = "cancelled"
-     STATUS_FAILED = "failed"
-     STATUS_COMPLETED = "completed"
-     STATUS_DELETED = "deleted"
-
-
- class FinetuneEventLevels(str, Enum):
-     """
-     Fine-tune job event status levels
-     """
-
-     NULL = ""
-     INFO = "Info"
-     WARNING = "Warning"
-     ERROR = "Error"
-     LEGACY_INFO = "info"
-     LEGACY_IWARNING = "warning"
-     LEGACY_IERROR = "error"
-
-
- class FinetuneEventType(str, Enum):
-     """
-     Fine-tune job event types
-     """
-
-     JOB_PENDING = "JOB_PENDING"
-     JOB_START = "JOB_START"
-     JOB_STOPPED = "JOB_STOPPED"
-     MODEL_DOWNLOADING = "MODEL_DOWNLOADING"
-     MODEL_DOWNLOAD_COMPLETE = "MODEL_DOWNLOAD_COMPLETE"
-     TRAINING_DATA_DOWNLOADING = "TRAINING_DATA_DOWNLOADING"
-     TRAINING_DATA_DOWNLOAD_COMPLETE = "TRAINING_DATA_DOWNLOAD_COMPLETE"
-     VALIDATION_DATA_DOWNLOADING = "VALIDATION_DATA_DOWNLOADING"
-     VALIDATION_DATA_DOWNLOAD_COMPLETE = "VALIDATION_DATA_DOWNLOAD_COMPLETE"
-     WANDB_INIT = "WANDB_INIT"
-     TRAINING_START = "TRAINING_START"
-     CHECKPOINT_SAVE = "CHECKPOINT_SAVE"
-     BILLING_LIMIT = "BILLING_LIMIT"
-     EPOCH_COMPLETE = "EPOCH_COMPLETE"
-     TRAINING_COMPLETE = "TRAINING_COMPLETE"
-     MODEL_COMPRESSING = "COMPRESSING_MODEL"
-     MODEL_COMPRESSION_COMPLETE = "MODEL_COMPRESSION_COMPLETE"
-     MODEL_UPLOADING = "MODEL_UPLOADING"
-     MODEL_UPLOAD_COMPLETE = "MODEL_UPLOAD_COMPLETE"
-     JOB_COMPLETE = "JOB_COMPLETE"
-     JOB_ERROR = "JOB_ERROR"
-     CANCEL_REQUESTED = "CANCEL_REQUESTED"
-     JOB_RESTARTED = "JOB_RESTARTED"
-     REFUND = "REFUND"
-     WARNING = "WARNING"
-
-
- class FineTuneType(str, Enum):
-     STANDARD = "STANDARD"
-     PREFERENCE = "PREFERENCE"
-     GRPO = "GRPO"
-
-
- class FinetuneEvent(BaseModel):
-     """
-     Fine-tune event type
-     """
-
-     # object type
-     object: Literal[ObjectType.FinetuneEvent]
-     # created at datetime stamp
-     created_at: datetime | None = None
-     # metrics that we expose
-     loss: float | None = None
-     epoch: float | None = None
-
-
- class LoRAConfig(BaseModel):
-     r: int = Field(8, gt=0, description="Rank of the update matrices.")
-     alpha: int = Field(32, gt=0, description="Scaling factor applied to LoRA updates.")
-     dropout: float = Field(
-         0.1,
-         ge=0.0,
-         le=1.0,
-         description="Fraction of LoRA neurons dropped during training.",
-     )
-     bias: Literal["none", "all", "lora_only"] = Field(
-         "none",
-         description="Bias terms to train; choose from 'none', 'all', or 'lora_only'.",
-     )
-     extras: Dict[str, Any] = Field(default_factory=dict)
-
-
- class TrainingConfig(BaseModel):
-     # training file ID
-     training_files: List[str]
-     # base model string
-     model: str
-     # number of epochs to train for
-     n_epochs: int
-     # training learning rate
-     learning_rate: float
-     # number of checkpoints to save
-     n_checkpoints: int | None = None
-     # training batch size
-     batch_size: int = Field(..., ge=1, le=1024)
-     # up to 40 character suffix for output model name
-     experiment_name: str | None = None
-     # sequence length
-     max_length: int = 2500
-     # # weights & biases api key
-     # wandb_key: str | None = None
-     # IFT by default
-     pre_train: bool = False
-     # fine-tune type
-     fine_tune_type: FineTuneType = FineTuneType.STANDARD
-     # LoRA config
-     lora_config: Optional[LoRAConfig] = None
-
-
- class AcceleratorType(str, Enum):
-     GAUDI2 = "GAUDI2"
-     GAUDI3 = "GAUDI3"
-     A100 = "A100"
-     A10 = "A10"
-     H100 = "H100"
-     MI300X = "MI300X"
-
-
- class InfrastructureConfig(BaseModel):
-     accel_type: AcceleratorType
-     n_accel: int
-     n_node: int = 1
-
-
- class FinetuneRequest(BaseModel):
-     """
-     Fine-tune request type
-     """
-
-     project_id: int
-     training_config: TrainingConfig
-     infrastructure_config: InfrastructureConfig
-
-
- class FinetuneResponse(BaseModel):
-     """
-     Fine-tune API response type
-     """
-
-     # job ID
-     id: str | None = None
-     # fine-tune type
-     fine_tune_type: FineTuneType = FineTuneType.STANDARD
-     # training file id
-     training_files: List[str] | None = None
-     # validation file id
-     # validation_files: str | None = None TODO
-     # base model name
-     model: str | None = None
-     accel_type: AcceleratorType
-     n_accel: int
-     n_node: int | None = None
-     # number of epochs
-     n_epochs: int | None = None
-     # number of checkpoints to save
-     # n_checkpoints: int | None = None # TODO
-     # training batch size
-     batch_size: int | None = None
-     # training learning rate
-     learning_rate: float | None = None
-     # LoRA configuration returned when LoRA fine-tuning is enabled
-     lora_config: Optional[LoRAConfig] = None
-     # number of steps between evals
-     # eval_steps: int | None = None TODO
-     # created/updated datetime stamps
-     created_at: datetime | None = None
-     # updated_at: str | None = None
-     # up to 40 character suffix for output model name
-     experiment_name: str | None = None
-     # job status
-     status: FinetuneJobStatus | None = None
-     deleted_at: datetime | None = None
-
-     # list of fine-tune events
-     events: List[FinetuneEvent] | None = None
-     inference_available: bool = False
-     project_id: Optional[int] = None  # TODO - fix this
-     completed_at: datetime | None = None
-     description: str | None = None
-
-     # dataset token count
-     # TODO
-     # token_count: int | None = None
-     # # model parameter count
-     # param_count: int | None = None
-     # # fine-tune job price
-     # total_price: int | None = None
-     # # number of epochs completed (incrementing counter)
-     # epochs_completed: int | None = None
-     # # place in job queue (decrementing counter)
-     # queue_depth: int | None = None
-     # # weights & biases project name
-     # wandb_project_name: str | None = None
-     # # weights & biases job url
-     # wandb_url: str | None = None
-     # # training file metadata
-     # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
-     # training_file_size: int | None = Field(None, alias="TrainingFileSize")
-
-
- class FinetuneList(BaseModel):
-     # object type
-     object: Literal["list"] | None = None
-     # list of fine-tune job objects
-     data: List[FinetuneResponse] | None = None
-
-
- class FinetuneListEvents(BaseModel):
-     # object type
-     object: Literal["list"] | None = None
-     # list of fine-tune events
-     data: List[FinetuneEvent] | None = None
-
-
- class FinetuneDownloadResult(BaseModel):
-     # object type
-     object: Literal["local"] | None = None
-     # fine-tune job id
-     id: str | None = None
-     # checkpoint step number
-     checkpoint_step: int | None = None
-     # local path filename
-     filename: str | None = None
-     # size in bytes
-     size: int | None = None