seekrai 0.5.25__tar.gz → 0.5.28__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {seekrai-0.5.25 → seekrai-0.5.28}/PKG-INFO +1 -1
  2. {seekrai-0.5.25 → seekrai-0.5.28}/pyproject.toml +1 -1
  3. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/chat/completions.py +12 -6
  4. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/chat_completions.py +1 -1
  5. seekrai-0.5.28/src/seekrai/types/finetune.py +459 -0
  6. seekrai-0.5.25/src/seekrai/types/finetune.py +0 -256
  7. {seekrai-0.5.25 → seekrai-0.5.28}/LICENSE +0 -0
  8. {seekrai-0.5.25 → seekrai-0.5.28}/README.md +0 -0
  9. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/__init__.py +0 -0
  10. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/abstract/__init__.py +0 -0
  11. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/abstract/api_requestor.py +0 -0
  12. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/abstract/response_parsing.py +0 -0
  13. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/client.py +0 -0
  14. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/constants.py +0 -0
  15. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/error.py +0 -0
  16. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/filemanager.py +0 -0
  17. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/__init__.py +0 -0
  18. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/agents/__init__.py +0 -0
  19. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/agents/agent_inference.py +0 -0
  20. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/agents/agent_observability.py +0 -0
  21. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/agents/agents.py +0 -0
  22. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/agents/python_functions.py +0 -0
  23. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/agents/threads.py +0 -0
  24. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/alignment.py +0 -0
  25. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/chat/__init__.py +0 -0
  26. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/completions.py +0 -0
  27. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/deployments.py +0 -0
  28. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/embeddings.py +0 -0
  29. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/explainability.py +0 -0
  30. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/files.py +0 -0
  31. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/finetune.py +0 -0
  32. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/images.py +0 -0
  33. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/ingestion.py +0 -0
  34. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/models.py +0 -0
  35. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/projects.py +0 -0
  36. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/resource_base.py +0 -0
  37. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/tools.py +0 -0
  38. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/resources/vectordb.py +0 -0
  39. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/seekrflow_response.py +0 -0
  40. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/__init__.py +0 -0
  41. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/abstract.py +0 -0
  42. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/__init__.py +0 -0
  43. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/agent.py +0 -0
  44. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/observability.py +0 -0
  45. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/python_functions.py +0 -0
  46. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/runs.py +0 -0
  47. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/threads.py +0 -0
  48. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/__init__.py +0 -0
  49. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/env_model_config.py +0 -0
  50. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/schemas/__init__.py +0 -0
  51. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/schemas/file_search.py +0 -0
  52. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/schemas/file_search_env.py +0 -0
  53. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/schemas/run_python.py +0 -0
  54. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/schemas/run_python_env.py +0 -0
  55. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/schemas/web_search.py +0 -0
  56. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/schemas/web_search_env.py +0 -0
  57. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/tool.py +0 -0
  58. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/agents/tools/tool_types.py +0 -0
  59. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/alignment.py +0 -0
  60. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/common.py +0 -0
  61. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/completions.py +0 -0
  62. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/deployments.py +0 -0
  63. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/embeddings.py +0 -0
  64. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/enums.py +0 -0
  65. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/error.py +0 -0
  66. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/explainability.py +0 -0
  67. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/files.py +0 -0
  68. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/images.py +0 -0
  69. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/ingestion.py +0 -0
  70. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/models.py +0 -0
  71. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/projects.py +0 -0
  72. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/tools.py +0 -0
  73. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/types/vectordb.py +0 -0
  74. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/utils/__init__.py +0 -0
  75. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/utils/_log.py +0 -0
  76. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/utils/api_helpers.py +0 -0
  77. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/utils/files.py +0 -0
  78. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/utils/tools.py +0 -0
  79. {seekrai-0.5.25 → seekrai-0.5.28}/src/seekrai/version.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: seekrai
-Version: 0.5.25
+Version: 0.5.28
 Summary: Python client for SeekrAI
 License: Apache-2.0
 License-File: LICENSE
@@ -13,7 +13,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "seekrai"
-version = "0.5.25"
+version = "0.5.28"
 authors = [
     "SeekrFlow <support@seekr.com>"
 ]
@@ -20,7 +20,7 @@ class ChatCompletions:
     def create(
         self,
         *,
-        messages: List[Dict[str, str]],
+        messages: List[Dict[str, Any]],
         model: str,
         max_completion_tokens: int | None = None,
         max_tokens: int | None = 512,
@@ -43,8 +43,11 @@ class ChatCompletions:
         Method to generate completions based on a given prompt using a specified model.
 
         Args:
-            messages (List[Dict[str, str]]): A list of messages in the format
-                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
+            messages (List[Dict[str, Any]]): A list of messages in the format
+                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]` or
+                `[{"role": seekrai.types.chat_completions.MessageRole, "content": PARTS}, ...]`
+                where PARTS is a list of content dicts, e.g. {"type": "text", "text": "..."} or
+                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
             model (str): The name of the model to query.
             max_completion_tokens (int, optional): The maximum number of tokens the output can contain.
             max_tokens (int, optional): The maximum number of tokens to generate.
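The widened messages parameter accepts either plain-string content or a list of typed content parts. A minimal sketch of such a payload, following the docstring above (the surrounding client call is omitted, and the "user" role string and truncated base64 data are illustrative placeholders):

messages = [
    # Plain-text content, unchanged from previous releases
    {"role": "user", "content": "Describe this image."},
    # Multimodal content expressed as a list of parts
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown here?"},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
        ],
    },
]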
@@ -141,7 +144,7 @@ class AsyncChatCompletions:
     async def create(
         self,
         *,
-        messages: List[Dict[str, str]],
+        messages: List[Dict[str, Any]],
        model: str,
         max_completion_tokens: int | None = None,
         max_tokens: int | None = 512,
@@ -164,8 +167,11 @@ class AsyncChatCompletions:
         Async method to generate completions based on a given prompt using a specified model.
 
         Args:
-            messages (List[Dict[str, str]]): A list of messages in the format
-                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
+            messages (List[Dict[str, Any]]): A list of messages in the format
+                `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]` or
+                `[{"role": seekrai.types.chat_completions.MessageRole, "content": PARTS}, ...]`
+                where PARTS is a list of content dicts, e.g. {"type": "text", "text": "..."} or
+                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
             model (str): The name of the model to query.
             max_completion_tokens (int, optional): The maximum number of tokens the output can contain.
             max_tokens (int, optional): The maximum number of tokens to generate.
@@ -40,7 +40,7 @@ class ToolCalls(BaseModel):
 
 class ChatCompletionMessage(BaseModel):
     role: MessageRole
-    content: str | None = None
+    content: str | List[Dict[str, Any]] | None = None
     # tool_calls: List[ToolCalls] | None = None
 
 
@@ -0,0 +1,459 @@
+from __future__ import annotations
+
+import warnings
+from datetime import datetime
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional
+
+from pydantic import Field, field_validator, model_serializer, model_validator
+
+from seekrai.types.abstract import BaseModel
+from seekrai.types.common import (
+    ObjectType,
+)
+from seekrai.utils._log import log_info
+
+
+class FinetuneJobStatus(str, Enum):
+    """
+    Possible fine-tune job status
+    """
+
+    STATUS_PENDING = "pending"
+    STATUS_QUEUED = "queued"
+    STATUS_RUNNING = "running"
+    # STATUS_COMPRESSING = "compressing"
+    # STATUS_UPLOADING = "uploading"
+    STATUS_CANCEL_REQUESTED = "cancel_requested"
+    STATUS_CANCELLED = "cancelled"
+    STATUS_FAILED = "failed"
+    STATUS_COMPLETED = "completed"
+    STATUS_DELETED = "deleted"
+
+
+class FinetuneEventLevels(str, Enum):
+    """
+    Fine-tune job event status levels
+    """
+
+    NULL = ""
+    INFO = "Info"
+    WARNING = "Warning"
+    ERROR = "Error"
+    LEGACY_INFO = "info"
+    LEGACY_IWARNING = "warning"
+    LEGACY_IERROR = "error"
+
+
+class FinetuneEventType(str, Enum):
+    """
+    Fine-tune job event types
+    """
+
+    JOB_PENDING = "JOB_PENDING"
+    JOB_START = "JOB_START"
+    JOB_STOPPED = "JOB_STOPPED"
+    MODEL_DOWNLOADING = "MODEL_DOWNLOADING"
+    MODEL_DOWNLOAD_COMPLETE = "MODEL_DOWNLOAD_COMPLETE"
+    TRAINING_DATA_DOWNLOADING = "TRAINING_DATA_DOWNLOADING"
+    TRAINING_DATA_DOWNLOAD_COMPLETE = "TRAINING_DATA_DOWNLOAD_COMPLETE"
+    VALIDATION_DATA_DOWNLOADING = "VALIDATION_DATA_DOWNLOADING"
+    VALIDATION_DATA_DOWNLOAD_COMPLETE = "VALIDATION_DATA_DOWNLOAD_COMPLETE"
+    WANDB_INIT = "WANDB_INIT"
+    TRAINING_START = "TRAINING_START"
+    CHECKPOINT_SAVE = "CHECKPOINT_SAVE"
+    BILLING_LIMIT = "BILLING_LIMIT"
+    EPOCH_COMPLETE = "EPOCH_COMPLETE"
+    TRAINING_COMPLETE = "TRAINING_COMPLETE"
+    MODEL_COMPRESSING = "COMPRESSING_MODEL"
+    MODEL_COMPRESSION_COMPLETE = "MODEL_COMPRESSION_COMPLETE"
+    MODEL_UPLOADING = "MODEL_UPLOADING"
+    MODEL_UPLOAD_COMPLETE = "MODEL_UPLOAD_COMPLETE"
+    JOB_COMPLETE = "JOB_COMPLETE"
+    JOB_ERROR = "JOB_ERROR"
+    CANCEL_REQUESTED = "CANCEL_REQUESTED"
+    JOB_RESTARTED = "JOB_RESTARTED"
+    REFUND = "REFUND"
+    WARNING = "WARNING"
+
+
+class FineTuneType(str, Enum):
+    STANDARD = "STANDARD"
+    GRPO = "GRPO"  # deprecated
+    PREFERENCE = "PREFERENCE"
+    REINFORCEMENT = "REINFORCEMENT"
+
+
+class GraderType(str, Enum):
+    FORMAT_CHECK = "format_check"
+    MATH_ACCURACY = "math_accuracy"
+    STRING_CHECK = "string_check"
+    TEXT_SIMILARITY = "text_similarity"
+
+
+class StringOperation(str, Enum):
+    EQUALS = "equals"
+    NOT_EQUALS = "not_equals"
+    CONTAINS = "contains"
+    CASE_INSENSITIVE_CONTAINS = "case_insensitive_contains"
+
+
+class TextSimilarityOperation(str, Enum):
+    BLEU = "bleu"
+    ROUGE = "rouge"
+
+
+class FinetuneEvent(BaseModel):
+    """
+    Fine-tune event type
+    """
+
+    # object type
+    object: Literal[ObjectType.FinetuneEvent]
+    # created at datetime stamp
+    created_at: datetime | None = None
+    # metrics that we expose
+    loss: float | None = None
+    epoch: float | None = None
+    reward: float | None = None
+
+    @model_serializer(mode="wrap")
+    def serialize_model(
+        self, handler: Callable[[Any], dict[str, Any]]
+    ) -> dict[str, Any]:
+        # Remove 'reward' if it's None
+        dump_dict = handler(self)
+        if dump_dict.get("reward") is None:
+            del dump_dict["reward"]
+        return dump_dict
+
+
+class LoRAConfig(BaseModel):
+    r: int = Field(8, gt=0, description="Rank of the update matrices.")
+    alpha: int = Field(32, gt=0, description="Scaling factor applied to LoRA updates.")
+    dropout: float = Field(
+        0.1,
+        ge=0.0,
+        le=1.0,
+        description="Fraction of LoRA neurons dropped during training.",
+    )
+    bias: Literal["none", "all", "lora_only"] = Field(
+        "none",
+        description="Bias terms to train; choose from 'none', 'all', or 'lora_only'.",
+    )
+    extras: Dict[str, Any] = Field(default_factory=dict)
+
+
+class Grader(BaseModel):
+    type: GraderType
+    weight: float | None = Field(default=None, gt=0.0, le=1.0)
+    operation: StringOperation | TextSimilarityOperation | None = Field(default=None)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_operation(cls, data: Any) -> Any:
+        if not isinstance(data, dict):
+            return data
+
+        grader_type = data.get("type")
+        operation_value = data.get("operation")
+
+        if grader_type == GraderType.STRING_CHECK:
+            if not operation_value:
+                raise ValueError(
+                    "string_check grader is missing required StringOperation"
+                )
+            if isinstance(operation_value, str):
+                try:
+                    # Convert to enum to validate it's a valid value
+                    data["operation"] = StringOperation(operation_value.lower())
+                except ValueError:
+                    raise ValueError(
+                        f"Invalid operation for string_check grader: "
+                        f"expected StringOperation, but got type '{type(operation_value).__name__}' with value '{operation_value}'"
+                    )
+        elif grader_type == GraderType.TEXT_SIMILARITY:
+            if not operation_value:
+                raise ValueError(
+                    "text_similarity grader is missing required TextSimilarityOperation"
+                )
+            if isinstance(operation_value, str):
+                try:
+                    data["operation"] = TextSimilarityOperation(operation_value.lower())
+                except ValueError:
+                    raise ValueError(
+                        f"Invalid operation for text_similarity grader: "
+                        f"expected TextSimilarityOperation, got type '{type(operation_value).__name__}' with value '{operation_value}'"
+                    )
+
+        elif grader_type in (GraderType.FORMAT_CHECK, GraderType.MATH_ACCURACY):
+            if operation_value:
+                raise ValueError(f"{grader_type} grader cannot have an operation")
+            data["operation"] = None
+
+        return data
+
+
+class RewardComponents(BaseModel):
+    format_reward_weight: float = Field(default=0.1, gt=0.0, le=1.0)
+    graders: list[Grader] = Field(min_length=1)
+
+    @model_validator(mode="after")
+    def validate_weights(self) -> "RewardComponents":
+        is_format_weight_specified = "format_reward_weight" in self.model_fields_set
+
+        grader_weights_specified = [
+            grader.weight is not None for grader in self.graders
+        ]
+
+        all_graders_have_weights = all(grader_weights_specified)
+        some_graders_have_weights = any(grader_weights_specified) and not all(
+            grader_weights_specified
+        )
+        no_graders_have_weights = not any(grader_weights_specified)
+
+        if some_graders_have_weights:
+            raise ValueError(
+                "Only some graders have weights specified. Either all graders must have weights specified, or none of them."
+            )
+
+        if all_graders_have_weights and is_format_weight_specified:
+            self._validate_weights_sum_to_one()
+
+        elif all_graders_have_weights and not is_format_weight_specified:
+            self._normalize_grader_weights()
+
+        elif no_graders_have_weights:
+            self._initialize_grader_weights()
+            self._normalize_grader_weights()
+
+        return self
+
+    def _validate_weights_sum_to_one(self) -> None:
+        """Validate that format_reward_weight and grader weights sum to 1.0"""
+        total_weight = self.format_reward_weight + sum(  # type: ignore[operator]
+            grader.weight  # type: ignore[misc]
+            for grader in self.graders
+        )
+
+        if abs(total_weight - 1.0) > 1e-10:
+            raise ValueError(
+                f"When all weights are explicitly provided, they must sum to 1.0. "
+                f"Got format_reward_weight={self.format_reward_weight}, "
+                f"graders={self.graders}"
+            )
+
+    def _normalize_grader_weights(self) -> None:
+        """Normalize only grader weights to fill (1 - format_reward_weight)"""
+        total_grader_weight = sum(grader.weight for grader in self.graders)  # type: ignore[misc]
+        target_grader_total = 1.0 - self.format_reward_weight
+
+        # only normalize if weights aren't already properly normalized
+        if abs(total_grader_weight - target_grader_total) > 1e-10:
+            scale_factor = target_grader_total / total_grader_weight
+            for grader in self.graders:
+                original_weight = grader.weight
+                grader.weight *= scale_factor  # type: ignore[operator]
+                log_info(
+                    f"{grader.type}'s weight scaled from {original_weight} to {grader.weight:.2f}"
+                )
+
+    def _initialize_grader_weights(self) -> None:
+        """Initialize all grader weights when none are provided"""
+        for grader in self.graders:
+            grader.weight = 1.0
+
+
+class TrainingConfig(BaseModel):
+    # training file ID
+    training_files: List[str]
+    # base model string
+    model: str
+    # number of epochs to train for
+    n_epochs: int
+    # training learning rate
+    learning_rate: float
+    # number of checkpoints to save
+    n_checkpoints: int | None = None
+    # training batch size
+    batch_size: int = Field(..., ge=1, le=1024)
+    # up to 40 character suffix for output model name
+    experiment_name: str | None = None
+    # sequence length
+    max_length: int = 2500
+    # # weights & biases api key
+    # wandb_key: str | None = None
+    # IFT by default
+    pre_train: bool = False
+    # fine-tune type
+    fine_tune_type: FineTuneType = FineTuneType.STANDARD
+    # LoRA config
+    lora_config: Optional[LoRAConfig] = None
+    # reward_components are REINFORCEMENT-specific
+    reward_components: Optional[RewardComponents] = None
+
+    @model_validator(mode="after")
+    def validate_reward_components(self) -> "TrainingConfig":
+        # TODO: re-enable the below and make reward_components required for REINFORCEMENT. Disabled for now for backwards-compatibility
+        # if (
+        #     self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
+        #     and not self.reward_components
+        # ):
+        #     raise ValueError("REINFORCEMENT fine-tuning requires reward components")
+        if (
+            self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
+            and not self.reward_components
+        ):
+            self.reward_components = RewardComponents(
+                format_reward_weight=0.1,
+                graders=[Grader(type=GraderType.MATH_ACCURACY, weight=0.9)],
+            )
+        if self.fine_tune_type == FineTuneType.STANDARD and self.reward_components:
+            raise ValueError(
+                "Reward components are incompatible with standard fine-tuning"
+            )
+        if self.fine_tune_type == FineTuneType.PREFERENCE and self.reward_components:
+            raise ValueError(
+                "Reward components are incompatible with preference fine-tuning"
+            )
+
+        return self
+
+    @field_validator("fine_tune_type")
+    def validate_fine_tune_type(cls, v: Any) -> Any:
+        if v == FineTuneType.GRPO:
+            warnings.warn(
+                "FineTuneType.GRPO is deprecated and will be removed in a future version. Use FineTuneType.REINFORCEMENT",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        return v
+
+
+class AcceleratorType(str, Enum):
+    GAUDI2 = "GAUDI2"
+    GAUDI3 = "GAUDI3"
+    A100 = "A100"
+    A10 = "A10"
+    H100 = "H100"
+    MI300X = "MI300X"
+
+
+class InfrastructureConfig(BaseModel):
+    accel_type: AcceleratorType
+    n_accel: int
+    n_node: int = 1
+
+
+class FinetuneRequest(BaseModel):
+    """
+    Fine-tune request type
+    """
+
+    project_id: int
+    training_config: TrainingConfig
+    infrastructure_config: InfrastructureConfig
+
+
+class FinetuneResponse(BaseModel):
+    """
+    Fine-tune API response type
+    """
+
+    # job ID
+    id: str | None = None
+    # fine-tune type
+    fine_tune_type: FineTuneType = FineTuneType.STANDARD
+    reward_components: Optional[RewardComponents] = None
+    # training file id
+    training_files: List[str] | None = None
+    # validation file id
+    # validation_files: str | None = None TODO
+    # base model name
+    model: str | None = None
+    accel_type: AcceleratorType
+    n_accel: int
+    n_node: int | None = None
+    # number of epochs
+    n_epochs: int | None = None
+    # number of checkpoints to save
+    # n_checkpoints: int | None = None # TODO
+    # training batch size
+    batch_size: int | None = None
+    # training learning rate
+    learning_rate: float | None = None
+    # LoRA configuration returned when LoRA fine-tuning is enabled
+    lora_config: Optional[LoRAConfig] = None
+    # number of steps between evals
+    # eval_steps: int | None = None TODO
+    # created/updated datetime stamps
+    created_at: datetime | None = None
+    # updated_at: str | None = None
+    # up to 40 character suffix for output model name
+    experiment_name: str | None = None
+    # job status
+    status: FinetuneJobStatus | None = None
+    deleted_at: datetime | None = None
+
+    # list of fine-tune events
+    events: List[FinetuneEvent] | None = None
+    inference_available: bool = False
+    project_id: Optional[int] = None  # TODO - fix this
+    completed_at: datetime | None = None
+    description: str | None = None
+
+    # dataset token count
+    # TODO
+    # token_count: int | None = None
+    # # model parameter count
+    # param_count: int | None = None
+    # # fine-tune job price
+    # total_price: int | None = None
+    # # number of epochs completed (incrementing counter)
+    # epochs_completed: int | None = None
+    # # place in job queue (decrementing counter)
+    # queue_depth: int | None = None
+    # # weights & biases project name
+    # wandb_project_name: str | None = None
+    # # weights & biases job url
+    # wandb_url: str | None = None
+    # # training file metadata
+    # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
+    # training_file_size: int | None = Field(None, alias="TrainingFileSize")
+
+    @model_serializer(mode="wrap")
+    def serialize_model(
+        self, handler: Callable[[Any], dict[str, Any]]
+    ) -> dict[str, Any]:
+        # Remove 'reward_components' if it's None
+        dump_dict = handler(self)
+        if dump_dict.get("reward_components") is None:
+            del dump_dict["reward_components"]
+        return dump_dict
+
+
+class FinetuneList(BaseModel):
+    # object type
+    object: Literal["list"] | None = None
+    # list of fine-tune job objects
+    data: List[FinetuneResponse] | None = None
+
+
+class FinetuneListEvents(BaseModel):
+    # object type
+    object: Literal["list"] | None = None
+    # list of fine-tune events
+    data: List[FinetuneEvent] | None = None
+
+
+class FinetuneDownloadResult(BaseModel):
+    # object type
+    object: Literal["local"] | None = None
+    # fine-tune job id
+    id: str | None = None
+    # checkpoint step number
+    checkpoint_step: int | None = None
+    # local path filename
+    filename: str | None = None
+    # size in bytes
+    size: int | None = None
@@ -1,256 +0,0 @@
-from __future__ import annotations
-
-from datetime import datetime
-from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
-
-from pydantic import Field
-
-from seekrai.types.abstract import BaseModel
-from seekrai.types.common import (
-    ObjectType,
-)
-
-
-class FinetuneJobStatus(str, Enum):
-    """
-    Possible fine-tune job status
-    """
-
-    STATUS_PENDING = "pending"
-    STATUS_QUEUED = "queued"
-    STATUS_RUNNING = "running"
-    # STATUS_COMPRESSING = "compressing"
-    # STATUS_UPLOADING = "uploading"
-    STATUS_CANCEL_REQUESTED = "cancel_requested"
-    STATUS_CANCELLED = "cancelled"
-    STATUS_FAILED = "failed"
-    STATUS_COMPLETED = "completed"
-    STATUS_DELETED = "deleted"
-
-
-class FinetuneEventLevels(str, Enum):
-    """
-    Fine-tune job event status levels
-    """
-
-    NULL = ""
-    INFO = "Info"
-    WARNING = "Warning"
-    ERROR = "Error"
-    LEGACY_INFO = "info"
-    LEGACY_IWARNING = "warning"
-    LEGACY_IERROR = "error"
-
-
-class FinetuneEventType(str, Enum):
-    """
-    Fine-tune job event types
-    """
-
-    JOB_PENDING = "JOB_PENDING"
-    JOB_START = "JOB_START"
-    JOB_STOPPED = "JOB_STOPPED"
-    MODEL_DOWNLOADING = "MODEL_DOWNLOADING"
-    MODEL_DOWNLOAD_COMPLETE = "MODEL_DOWNLOAD_COMPLETE"
-    TRAINING_DATA_DOWNLOADING = "TRAINING_DATA_DOWNLOADING"
-    TRAINING_DATA_DOWNLOAD_COMPLETE = "TRAINING_DATA_DOWNLOAD_COMPLETE"
-    VALIDATION_DATA_DOWNLOADING = "VALIDATION_DATA_DOWNLOADING"
-    VALIDATION_DATA_DOWNLOAD_COMPLETE = "VALIDATION_DATA_DOWNLOAD_COMPLETE"
-    WANDB_INIT = "WANDB_INIT"
-    TRAINING_START = "TRAINING_START"
-    CHECKPOINT_SAVE = "CHECKPOINT_SAVE"
-    BILLING_LIMIT = "BILLING_LIMIT"
-    EPOCH_COMPLETE = "EPOCH_COMPLETE"
-    TRAINING_COMPLETE = "TRAINING_COMPLETE"
-    MODEL_COMPRESSING = "COMPRESSING_MODEL"
-    MODEL_COMPRESSION_COMPLETE = "MODEL_COMPRESSION_COMPLETE"
-    MODEL_UPLOADING = "MODEL_UPLOADING"
-    MODEL_UPLOAD_COMPLETE = "MODEL_UPLOAD_COMPLETE"
-    JOB_COMPLETE = "JOB_COMPLETE"
-    JOB_ERROR = "JOB_ERROR"
-    CANCEL_REQUESTED = "CANCEL_REQUESTED"
-    JOB_RESTARTED = "JOB_RESTARTED"
-    REFUND = "REFUND"
-    WARNING = "WARNING"
-
-
-class FineTuneType(str, Enum):
-    STANDARD = "STANDARD"
-    PREFERENCE = "PREFERENCE"
-    GRPO = "GRPO"
-
-
-class FinetuneEvent(BaseModel):
-    """
-    Fine-tune event type
-    """
-
-    # object type
-    object: Literal[ObjectType.FinetuneEvent]
-    # created at datetime stamp
-    created_at: datetime | None = None
-    # metrics that we expose
-    loss: float | None = None
-    epoch: float | None = None
-
-
-class LoRAConfig(BaseModel):
-    r: int = Field(8, gt=0, description="Rank of the update matrices.")
-    alpha: int = Field(32, gt=0, description="Scaling factor applied to LoRA updates.")
-    dropout: float = Field(
-        0.1,
-        ge=0.0,
-        le=1.0,
-        description="Fraction of LoRA neurons dropped during training.",
-    )
-    bias: Literal["none", "all", "lora_only"] = Field(
-        "none",
-        description="Bias terms to train; choose from 'none', 'all', or 'lora_only'.",
-    )
-    extras: Dict[str, Any] = Field(default_factory=dict)
-
-
-class TrainingConfig(BaseModel):
-    # training file ID
-    training_files: List[str]
-    # base model string
-    model: str
-    # number of epochs to train for
-    n_epochs: int
-    # training learning rate
-    learning_rate: float
-    # number of checkpoints to save
-    n_checkpoints: int | None = None
-    # training batch size
-    batch_size: int = Field(..., ge=1, le=1024)
-    # up to 40 character suffix for output model name
-    experiment_name: str | None = None
-    # sequence length
-    max_length: int = 2500
-    # # weights & biases api key
-    # wandb_key: str | None = None
-    # IFT by default
-    pre_train: bool = False
-    # fine-tune type
-    fine_tune_type: FineTuneType = FineTuneType.STANDARD
-    # LoRA config
-    lora_config: Optional[LoRAConfig] = None
-
-
-class AcceleratorType(str, Enum):
-    GAUDI2 = "GAUDI2"
-    GAUDI3 = "GAUDI3"
-    A100 = "A100"
-    A10 = "A10"
-    H100 = "H100"
-    MI300X = "MI300X"
-
-
-class InfrastructureConfig(BaseModel):
-    accel_type: AcceleratorType
-    n_accel: int
-    n_node: int = 1
-
-
-class FinetuneRequest(BaseModel):
-    """
-    Fine-tune request type
-    """
-
-    project_id: int
-    training_config: TrainingConfig
-    infrastructure_config: InfrastructureConfig
-
-
-class FinetuneResponse(BaseModel):
-    """
-    Fine-tune API response type
-    """
-
-    # job ID
-    id: str | None = None
-    # fine-tune type
-    fine_tune_type: FineTuneType = FineTuneType.STANDARD
-    # training file id
-    training_files: List[str] | None = None
-    # validation file id
-    # validation_files: str | None = None TODO
-    # base model name
-    model: str | None = None
-    accel_type: AcceleratorType
-    n_accel: int
-    n_node: int | None = None
-    # number of epochs
-    n_epochs: int | None = None
-    # number of checkpoints to save
-    # n_checkpoints: int | None = None # TODO
-    # training batch size
-    batch_size: int | None = None
-    # training learning rate
-    learning_rate: float | None = None
-    # LoRA configuration returned when LoRA fine-tuning is enabled
-    lora_config: Optional[LoRAConfig] = None
-    # number of steps between evals
-    # eval_steps: int | None = None TODO
-    # created/updated datetime stamps
-    created_at: datetime | None = None
-    # updated_at: str | None = None
-    # up to 40 character suffix for output model name
-    experiment_name: str | None = None
-    # job status
-    status: FinetuneJobStatus | None = None
-    deleted_at: datetime | None = None
-
-    # list of fine-tune events
-    events: List[FinetuneEvent] | None = None
-    inference_available: bool = False
-    project_id: Optional[int] = None  # TODO - fix this
-    completed_at: datetime | None = None
-    description: str | None = None
-
-    # dataset token count
-    # TODO
-    # token_count: int | None = None
-    # # model parameter count
-    # param_count: int | None = None
-    # # fine-tune job price
-    # total_price: int | None = None
-    # # number of epochs completed (incrementing counter)
-    # epochs_completed: int | None = None
-    # # place in job queue (decrementing counter)
-    # queue_depth: int | None = None
-    # # weights & biases project name
-    # wandb_project_name: str | None = None
-    # # weights & biases job url
-    # wandb_url: str | None = None
-    # # training file metadata
-    # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
-    # training_file_size: int | None = Field(None, alias="TrainingFileSize")
-
-
-class FinetuneList(BaseModel):
-    # object type
-    object: Literal["list"] | None = None
-    # list of fine-tune job objects
-    data: List[FinetuneResponse] | None = None
-
-
-class FinetuneListEvents(BaseModel):
-    # object type
-    object: Literal["list"] | None = None
-    # list of fine-tune events
-    data: List[FinetuneEvent] | None = None
-
-
-class FinetuneDownloadResult(BaseModel):
-    # object type
-    object: Literal["local"] | None = None
-    # fine-tune job id
-    id: str | None = None
-    # checkpoint step number
-    checkpoint_step: int | None = None
-    # local path filename
-    filename: str | None = None
-    # size in bytes
-    size: int | None = None
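The reward-component models introduced in the new finetune.py can be composed directly into a fine-tune request. A minimal sketch of a REINFORCEMENT job built from the added types (the project ID, file ID, and model name are placeholders, not real resources):

from seekrai.types.finetune import (
    AcceleratorType,
    FineTuneType,
    FinetuneRequest,
    Grader,
    GraderType,
    InfrastructureConfig,
    RewardComponents,
    StringOperation,
    TrainingConfig,
)

# Graders left without explicit weights are initialized to 1.0 and then scaled to
# fill 1 - format_reward_weight: with the default 0.1, these two graders end up
# at 0.45 each (see validate_weights and _normalize_grader_weights above).
reward = RewardComponents(
    graders=[
        Grader(type=GraderType.MATH_ACCURACY),
        Grader(type=GraderType.STRING_CHECK, operation=StringOperation.CONTAINS),
    ]
)

request = FinetuneRequest(
    project_id=123,  # placeholder project ID
    training_config=TrainingConfig(
        training_files=["file-abc123"],  # placeholder file ID
        model="example-base-model",      # placeholder model name
        n_epochs=1,
        learning_rate=1e-5,
        batch_size=8,
        fine_tune_type=FineTuneType.REINFORCEMENT,
        reward_components=reward,
    ),
    infrastructure_config=InfrastructureConfig(
        accel_type=AcceleratorType.GAUDI2, n_accel=8
    ),
)

Requests that still pass FineTuneType.GRPO behave the same way but emit a DeprecationWarning, and omitting reward_components for REINFORCEMENT or GRPO falls back to the math_accuracy default shown in validate_reward_components.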