seekrai 0.5.25__py3-none-any.whl → 0.5.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,7 +20,7 @@ class ChatCompletions:
20
20
  def create(
21
21
  self,
22
22
  *,
23
- messages: List[Dict[str, str]],
23
+ messages: List[Dict[str, Any]],
24
24
  model: str,
25
25
  max_completion_tokens: int | None = None,
26
26
  max_tokens: int | None = 512,
@@ -43,8 +43,11 @@ class ChatCompletions:
43
43
  Method to generate completions based on a given prompt using a specified model.
44
44
 
45
45
  Args:
46
- messages (List[Dict[str, str]]): A list of messages in the format
47
- `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
46
+ messages (List[Dict[str, Any]]): A list of messages in the format
47
+ `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]` or
48
+ `[{"role": seekrai.types.chat_completions.MessageRole, "content": PARTS}, ...]`
49
+ where PARTS is a list of content dicts, e.g. {"type": "text", "text": "..."} or
50
+ {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
48
51
  model (str): The name of the model to query.
49
52
  max_completion_tokens (int, optional): The maximum number of tokens the output can contain.
50
53
  max_tokens (int, optional): The maximum number of tokens to generate.
@@ -141,7 +144,7 @@ class AsyncChatCompletions:
141
144
  async def create(
142
145
  self,
143
146
  *,
144
- messages: List[Dict[str, str]],
147
+ messages: List[Dict[str, Any]],
145
148
  model: str,
146
149
  max_completion_tokens: int | None = None,
147
150
  max_tokens: int | None = 512,
@@ -164,8 +167,11 @@ class AsyncChatCompletions:
164
167
  Async method to generate completions based on a given prompt using a specified model.
165
168
 
166
169
  Args:
167
- messages (List[Dict[str, str]]): A list of messages in the format
168
- `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]`
170
+ messages (List[Dict[str, Any]]): A list of messages in the format
171
+ `[{"role": seekrai.types.chat_completions.MessageRole, "content": TEXT}, ...]` or
172
+ `[{"role": seekrai.types.chat_completions.MessageRole, "content": PARTS}, ...]`
173
+ where PARTS is a list of content dicts, e.g. {"type": "text", "text": "..."} or
174
+ {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
169
175
  model (str): The name of the model to query.
170
176
  max_completion_tokens (int, optional): The maximum number of tokens the output can contain.
171
177
  max_tokens (int, optional): The maximum number of tokens to generate.
@@ -40,7 +40,7 @@ class ToolCalls(BaseModel):
40
40
 
41
41
  class ChatCompletionMessage(BaseModel):
42
42
  role: MessageRole
43
- content: str | None = None
43
+ content: str | List[Dict[str, Any]] | None = None
44
44
  # tool_calls: List[ToolCalls] | None = None
45
45
 
46
46
 
seekrai/types/finetune.py CHANGED
@@ -1,15 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import warnings
3
4
  from datetime import datetime
4
5
  from enum import Enum
5
- from typing import Any, Dict, List, Literal, Optional
6
+ from typing import Any, Callable, Dict, List, Literal, Optional
6
7
 
7
- from pydantic import Field
8
+ from pydantic import Field, field_validator, model_serializer, model_validator
8
9
 
9
10
  from seekrai.types.abstract import BaseModel
10
11
  from seekrai.types.common import (
11
12
  ObjectType,
12
13
  )
14
+ from seekrai.utils._log import log_info
13
15
 
14
16
 
15
17
  class FinetuneJobStatus(str, Enum):
@@ -77,8 +79,28 @@ class FinetuneEventType(str, Enum):
77
79
 
78
80
  class FineTuneType(str, Enum):
79
81
  STANDARD = "STANDARD"
82
+ GRPO = "GRPO" # deprecated
80
83
  PREFERENCE = "PREFERENCE"
81
- GRPO = "GRPO"
84
+ REINFORCEMENT = "REINFORCEMENT"
85
+
86
+
87
+ class GraderType(str, Enum):
88
+ FORMAT_CHECK = "format_check"
89
+ MATH_ACCURACY = "math_accuracy"
90
+ STRING_CHECK = "string_check"
91
+ TEXT_SIMILARITY = "text_similarity"
92
+
93
+
94
+ class StringOperation(str, Enum):
95
+ EQUALS = "equals"
96
+ NOT_EQUALS = "not_equals"
97
+ CONTAINS = "contains"
98
+ CASE_INSENSITIVE_CONTAINS = "case_insensitive_contains"
99
+
100
+
101
+ class TextSimilarityOperation(str, Enum):
102
+ BLEU = "bleu"
103
+ ROUGE = "rouge"
82
104
 
83
105
 
84
106
  class FinetuneEvent(BaseModel):
@@ -93,6 +115,17 @@ class FinetuneEvent(BaseModel):
93
115
  # metrics that we expose
94
116
  loss: float | None = None
95
117
  epoch: float | None = None
118
+ reward: float | None = None
119
+
120
+ @model_serializer(mode="wrap")
121
+ def serialize_model(
122
+ self, handler: Callable[[Any], dict[str, Any]]
123
+ ) -> dict[str, Any]:
124
+ # Remove 'reward' if it's None
125
+ dump_dict = handler(self)
126
+ if dump_dict.get("reward") is None:
127
+ del dump_dict["reward"]
128
+ return dump_dict
96
129
 
97
130
 
98
131
  class LoRAConfig(BaseModel):
@@ -111,6 +144,126 @@ class LoRAConfig(BaseModel):
111
144
  extras: Dict[str, Any] = Field(default_factory=dict)
112
145
 
113
146
 
147
+ class Grader(BaseModel):
148
+ type: GraderType
149
+ weight: float | None = Field(default=None, gt=0.0, le=1.0)
150
+ operation: StringOperation | TextSimilarityOperation | None = Field(default=None)
151
+
152
+ @model_validator(mode="before")
153
+ @classmethod
154
+ def validate_operation(cls, data: Any) -> Any:
155
+ if not isinstance(data, dict):
156
+ return data
157
+
158
+ grader_type = data.get("type")
159
+ operation_value = data.get("operation")
160
+
161
+ if grader_type == GraderType.STRING_CHECK:
162
+ if not operation_value:
163
+ raise ValueError(
164
+ "string_check grader is missing required StringOperation"
165
+ )
166
+ if isinstance(operation_value, str):
167
+ try:
168
+ # Convert to enum to validate it's a valid value
169
+ data["operation"] = StringOperation(operation_value.lower())
170
+ except ValueError:
171
+ raise ValueError(
172
+ f"Invalid operation for string_check grader: "
173
+ f"expected StringOperation, but got type '{type(operation_value).__name__}' with value '{operation_value}'"
174
+ )
175
+ elif grader_type == GraderType.TEXT_SIMILARITY:
176
+ if not operation_value:
177
+ raise ValueError(
178
+ "text_similarity grader is missing required TextSimilarityOperation"
179
+ )
180
+ if isinstance(operation_value, str):
181
+ try:
182
+ data["operation"] = TextSimilarityOperation(operation_value.lower())
183
+ except ValueError:
184
+ raise ValueError(
185
+ f"Invalid operation for text_similarity grader: "
186
+ f"expected TextSimilarityOperation, got type '{type(operation_value).__name__}' with value '{operation_value}'"
187
+ )
188
+
189
+ elif grader_type in (GraderType.FORMAT_CHECK, GraderType.MATH_ACCURACY):
190
+ if operation_value:
191
+ raise ValueError(f"{grader_type} grader cannot have an operation")
192
+ data["operation"] = None
193
+
194
+ return data
195
+
196
+
197
+ class RewardComponents(BaseModel):
198
+ format_reward_weight: float = Field(default=0.1, gt=0.0, le=1.0)
199
+ graders: list[Grader] = Field(min_length=1)
200
+
201
+ @model_validator(mode="after")
202
+ def validate_weights(self) -> "RewardComponents":
203
+ is_format_weight_specified = "format_reward_weight" in self.model_fields_set
204
+
205
+ grader_weights_specified = [
206
+ grader.weight is not None for grader in self.graders
207
+ ]
208
+
209
+ all_graders_have_weights = all(grader_weights_specified)
210
+ some_graders_have_weights = any(grader_weights_specified) and not all(
211
+ grader_weights_specified
212
+ )
213
+ no_graders_have_weights = not any(grader_weights_specified)
214
+
215
+ if some_graders_have_weights:
216
+ raise ValueError(
217
+ "Only some graders have weights specified. Either all graders must have weights specified, or none of them."
218
+ )
219
+
220
+ if all_graders_have_weights and is_format_weight_specified:
221
+ self._validate_weights_sum_to_one()
222
+
223
+ elif all_graders_have_weights and not is_format_weight_specified:
224
+ self._normalize_grader_weights()
225
+
226
+ elif no_graders_have_weights:
227
+ self._initialize_grader_weights()
228
+ self._normalize_grader_weights()
229
+
230
+ return self
231
+
232
+ def _validate_weights_sum_to_one(self) -> None:
233
+ """Validate that format_reward_weight and grader weights sum to 1.0"""
234
+ total_weight = self.format_reward_weight + sum( # type: ignore[operator]
235
+ grader.weight # type: ignore[misc]
236
+ for grader in self.graders
237
+ )
238
+
239
+ if abs(total_weight - 1.0) > 1e-10:
240
+ raise ValueError(
241
+ f"When all weights are explicitly provided, they must sum to 1.0. "
242
+ f"Got format_reward_weight={self.format_reward_weight}, "
243
+ f"graders={self.graders}"
244
+ )
245
+
246
+ def _normalize_grader_weights(self) -> None:
247
+ """Normalize only grader weights to fill (1 - format_reward_weight)"""
248
+ total_grader_weight = sum(grader.weight for grader in self.graders) # type: ignore[misc]
249
+ target_grader_total = 1.0 - self.format_reward_weight
250
+
251
+ # only normalize if weights aren't already properly normalized
252
+ if abs(total_grader_weight - target_grader_total) > 1e-10:
253
+ scale_factor = target_grader_total / total_grader_weight
254
+ for grader in self.graders:
255
+ original_weight = grader.weight
256
+ grader.weight *= scale_factor # type: ignore[operator]
257
+ log_info(
258
+ f"{grader.type}'s weight scaled from {original_weight} to {grader.weight:.2f}"
259
+ )
260
+
261
+ def _initialize_grader_weights(self) -> None:
262
+ """Initialize all grader weights when none are provided"""
263
+ for grader in self.graders:
264
+ grader.weight = 1.0
265
+
266
+
114
267
  class TrainingConfig(BaseModel):
115
268
  # training file ID
116
269
  training_files: List[str]
@@ -136,6 +289,45 @@ class TrainingConfig(BaseModel):
136
289
  fine_tune_type: FineTuneType = FineTuneType.STANDARD
137
290
  # LoRA config
138
291
  lora_config: Optional[LoRAConfig] = None
292
+ # reward_components are REINFORCEMENT-specific
293
+ reward_components: Optional[RewardComponents] = None
294
+
295
+ @model_validator(mode="after")
296
+ def validate_reward_components(self) -> "TrainingConfig":
297
+ # TODO: re-enable the below and make reward_components required for REINFORCEMENT. Disabled for now for backwards-compatibility
298
+ # if (
299
+ # self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
300
+ # and not self.reward_components
301
+ # ):
302
+ # raise ValueError("REINFORCEMENT fine-tuning requires reward components")
303
+ if (
304
+ self.fine_tune_type in (FineTuneType.REINFORCEMENT, FineTuneType.GRPO)
305
+ and not self.reward_components
306
+ ):
307
+ self.reward_components = RewardComponents(
308
+ format_reward_weight=0.1,
309
+ graders=[Grader(type=GraderType.MATH_ACCURACY, weight=0.9)],
310
+ )
311
+ if self.fine_tune_type == FineTuneType.STANDARD and self.reward_components:
312
+ raise ValueError(
313
+ "Reward components are incompatible with standard fine-tuning"
314
+ )
315
+ if self.fine_tune_type == FineTuneType.PREFERENCE and self.reward_components:
316
+ raise ValueError(
317
+ "Reward components are incompatible with preference fine-tuning"
318
+ )
319
+
320
+ return self
321
+
322
+ @field_validator("fine_tune_type")
323
+ def validate_fine_tune_type(cls, v: Any) -> Any:
324
+ if v == FineTuneType.GRPO:
325
+ warnings.warn(
326
+ "FineTuneType.GRPO is deprecated and will be removed in a future version. Use FineTuneType.REINFORCEMENT",
327
+ DeprecationWarning,
328
+ stacklevel=2,
329
+ )
330
+ return v
139
331
 
140
332
 
141
333
  class AcceleratorType(str, Enum):
@@ -172,6 +364,7 @@ class FinetuneResponse(BaseModel):
172
364
  id: str | None = None
173
365
  # fine-tune type
174
366
  fine_tune_type: FineTuneType = FineTuneType.STANDARD
367
+ reward_components: Optional[RewardComponents] = None
175
368
  # training file id
176
369
  training_files: List[str] | None = None
177
370
  # validation file id
@@ -228,6 +421,16 @@ class FinetuneResponse(BaseModel):
228
421
  # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
229
422
  # training_file_size: int | None = Field(None, alias="TrainingFileSize")
230
423
 
424
+ @model_serializer(mode="wrap")
425
+ def serialize_model(
426
+ self, handler: Callable[[Any], dict[str, Any]]
427
+ ) -> dict[str, Any]:
428
+ # Remove 'reward_components' if it's None
429
+ dump_dict = handler(self)
430
+ if dump_dict.get("reward_components") is None:
431
+ del dump_dict["reward_components"]
432
+ return dump_dict
433
+
231
434
 
232
435
  class FinetuneList(BaseModel):
233
436
  # object type
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: seekrai
3
- Version: 0.5.25
3
+ Version: 0.5.28
4
4
  Summary: Python client for SeekrAI
5
5
  License: Apache-2.0
6
6
  License-File: LICENSE
@@ -15,7 +15,7 @@ seekrai/resources/agents/python_functions.py,sha256=VL1JSsPP5nN1m8I0Ihe3AYlb8TMR
15
15
  seekrai/resources/agents/threads.py,sha256=BwZ2_6wlezsb12PQjEw1fgdJh5S83SPgD6qZQoGvyIM,14544
16
16
  seekrai/resources/alignment.py,sha256=htpwu41NPOb_3aAM6fCilVBcEXlzIYF_req-Cf6OrpI,19627
17
17
  seekrai/resources/chat/__init__.py,sha256=KmtPupgECtEN80NyvcnSmieTAFXhwmVxhMHP0qhspA4,618
18
- seekrai/resources/chat/completions.py,sha256=-nFk5aL1ejZ9WUi_ksDHqCnEnT2CjPX-RaigT3N35gA,12317
18
+ seekrai/resources/chat/completions.py,sha256=DN3xSzDFJFOa8iU-E38NNFb5Ax0XLv4U5QatQiMS9Ig,12887
19
19
  seekrai/resources/completions.py,sha256=JhTN_lW2mblfHHONFmPC7QZei3wo5vx6GliMs9FkbOY,8452
20
20
  seekrai/resources/deployments.py,sha256=DY7IN7QgqDduCHGNuHENSVwrE5PXFL88jWgh8SES7Qk,5970
21
21
  seekrai/resources/embeddings.py,sha256=7G-VisYrT9J35-hcKB8cXhs8BSi93IfveQKfVSC7diA,2585
@@ -50,7 +50,7 @@ seekrai/types/agents/tools/schemas/web_search_env.py,sha256=R3fGXV43ZacbgoSy-M43
50
50
  seekrai/types/agents/tools/tool.py,sha256=XdbYzSn0Rr-m9s34Aa1NDQ-35Kpzk4rcaXY3dKrg-Ys,514
51
51
  seekrai/types/agents/tools/tool_types.py,sha256=1tF_kE6Z_zzuZpOAK1HrHsHkXFPEoK0PdYv-pbTLfkY,360
52
52
  seekrai/types/alignment.py,sha256=nWcc4kQLs40-T0_HC3MnGkLd-StwBvwCXQrjUVJ5dEI,2973
53
- seekrai/types/chat_completions.py,sha256=Z7H1MkMgb4O0O5LDMKotQqhjGVCYk5eBeZ8n--RJpf8,3736
53
+ seekrai/types/chat_completions.py,sha256=-ganbaf5wDlNIj6zWGjcmkCwTIoCUthyVO9vuBcQiwk,3759
54
54
  seekrai/types/common.py,sha256=YI1pE-i_lDLU2o6FjoINdIhPXsV9lUl2MeAg2aRtT-M,2062
55
55
  seekrai/types/completions.py,sha256=lm9AFdZR3Xg5AHPkV-qETHikkwMJmkHrLGr5GG-YR-M,2171
56
56
  seekrai/types/deployments.py,sha256=a0zew1DuB9vPQXcBT2R4Tdn_8z5qleh6V6i4T4xyYZo,1798
@@ -59,7 +59,7 @@ seekrai/types/enums.py,sha256=sQ1CW-ctbhpV2jM1cEAEy7ZUdzZa0IC85YvycjvudHE,633
59
59
  seekrai/types/error.py,sha256=uTKISs9aRC4_6zwirtNkanxepN8KY-SqCq0kNbfZylQ,370
60
60
  seekrai/types/explainability.py,sha256=Ih-8hCm5r22EMMtr83cDy8vePo7_Ik7UdUcXhsj5Zm0,835
61
61
  seekrai/types/files.py,sha256=kOy4s8D4tlsenyWmiiEyAS0jDAdxMScBu5j1GwQCf3E,2808
62
- seekrai/types/finetune.py,sha256=VHAzIvU-B99TEVsuwl0pf8TODFOMYKT1dxr0kRX4Z4o,7218
62
+ seekrai/types/finetune.py,sha256=-dRSjRqJVu2-dEfykOJYTuuzt6Ok1nx91gJzQ_WAqEU,15341
63
63
  seekrai/types/images.py,sha256=Fusj8OhVYFsT8kz636lRGGivLbPXo_ZNgakKwmzJi3U,914
64
64
  seekrai/types/ingestion.py,sha256=uUdKOR4xqSfAXWQOR1UOltSlOnuyAwKVA1Q2a6Yslk8,919
65
65
  seekrai/types/models.py,sha256=9Z0nvLdlAfpF8mNRW5-IqBdDHoE-3qQ5przmIDJgwLo,1345
@@ -72,8 +72,8 @@ seekrai/utils/api_helpers.py,sha256=0Y8BblNIr9h_R12zdmhkxgTlxgoRkbq84QNi4nNWGu8,
72
72
  seekrai/utils/files.py,sha256=7ixn_hgV-6pEhYqLyOp-EN0o8c1CzUwJzX9n3PQ5oqo,7164
73
73
  seekrai/utils/tools.py,sha256=jgJTL-dOIouDbEJLdQpQfpXhqaz_poQYS52adyUtBjo,1781
74
74
  seekrai/version.py,sha256=q6iGQVFor8zXiPP5F-3vy9TndOxKv5JXbaNJ2kdOQws,125
75
- seekrai-0.5.25.dist-info/METADATA,sha256=SWvIVZI3Tb-_0QVntTa7DoRoqHYPsXOSCfhEl5fSn4M,4788
76
- seekrai-0.5.25.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
77
- seekrai-0.5.25.dist-info/entry_points.txt,sha256=N49yOEGi1sK7Xr13F_rkkcOxQ88suyiMoOmRhUHTZ_U,48
78
- seekrai-0.5.25.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
79
- seekrai-0.5.25.dist-info/RECORD,,
75
+ seekrai-0.5.28.dist-info/METADATA,sha256=LQLnopWdiNd6l3OJP92tv-DzgJWOa_jw_0lbXooI2JE,4788
76
+ seekrai-0.5.28.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
77
+ seekrai-0.5.28.dist-info/entry_points.txt,sha256=N49yOEGi1sK7Xr13F_rkkcOxQ88suyiMoOmRhUHTZ_U,48
78
+ seekrai-0.5.28.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
79
+ seekrai-0.5.28.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.3.0
2
+ Generator: poetry-core 2.3.1
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any